Spaces:
Runtime error
Runtime error
Duplicate from songweig/rich-text-to-image
Browse filesCo-authored-by: Songwei Ge <[email protected]>
- .gitattributes +34 -0
- .gitignore +6 -0
- README.md +13 -0
- app.py +514 -0
- app_sd.py +557 -0
- models/attention.py +391 -0
- models/attention_processor.py +1687 -0
- models/dual_transformer_2d.py +151 -0
- models/region_diffusion.py +521 -0
- models/region_diffusion_xl.py +1143 -0
- models/resnet.py +882 -0
- models/transformer_2d.py +341 -0
- models/unet_2d_blocks.py +0 -0
- models/unet_2d_condition.py +983 -0
- requirements.txt +11 -0
- rich-text-to-json-iframe.html +341 -0
- rich-text-to-json.js +349 -0
- share_btn.py +116 -0
- utils/.DS_Store +0 -0
- utils/attention_utils.py +724 -0
- utils/richtext_utils.py +234 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
venv
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
4 |
+
*.png
|
5 |
+
*.jpg
|
6 |
+
gradio_cached_examples/
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Rich Text To Image
|
3 |
+
emoji: π
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.27.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: songweig/rich-text-to-image
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import argparse
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
from models.region_diffusion_xl import RegionDiffusionXL
|
12 |
+
from utils.attention_utils import get_token_maps
|
13 |
+
from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
|
14 |
+
get_attention_control_input, get_gradient_guidance_input
|
15 |
+
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from PIL import Image, ImageOps
|
19 |
+
from share_btn import community_icon_html, loading_icon_html, share_js, css
|
20 |
+
|
21 |
+
|
22 |
+
help_text = """
|
23 |
+
If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
|
24 |
+
1. If you format only a portion of a word rather than the complete word, an error may occur.
|
25 |
+
2. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
|
26 |
+
3. Consider using a different seed.
|
27 |
+
"""
|
28 |
+
|
29 |
+
|
30 |
+
canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
|
31 |
+
get_js_data = """
|
32 |
+
async (text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, rich_text_input, height, width, steps, guidance_weights) => {
|
33 |
+
const richEl = document.getElementById("rich-text-root");
|
34 |
+
const data = richEl? richEl.contentDocument.body._data : {};
|
35 |
+
return [text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, JSON.stringify(data), height, width, steps, guidance_weights];
|
36 |
+
}
|
37 |
+
"""
|
38 |
+
set_js_data = """
|
39 |
+
async (text_input) => {
|
40 |
+
const richEl = document.getElementById("rich-text-root");
|
41 |
+
const data = text_input ? JSON.parse(text_input) : null;
|
42 |
+
if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
|
43 |
+
}
|
44 |
+
"""
|
45 |
+
|
46 |
+
get_window_url_params = """
|
47 |
+
async (url_params) => {
|
48 |
+
const params = new URLSearchParams(window.location.search);
|
49 |
+
url_params = Object.fromEntries(params);
|
50 |
+
return [url_params];
|
51 |
+
}
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
def load_url_params(url_params):
|
56 |
+
if 'prompt' in url_params:
|
57 |
+
return gr.update(visible=True), url_params
|
58 |
+
else:
|
59 |
+
return gr.update(visible=False), url_params
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
64 |
+
model = RegionDiffusionXL()
|
65 |
+
|
66 |
+
def generate(
|
67 |
+
text_input: str,
|
68 |
+
negative_text: str,
|
69 |
+
num_segments: int,
|
70 |
+
segment_threshold: float,
|
71 |
+
inject_interval: float,
|
72 |
+
inject_background: float,
|
73 |
+
seed: int,
|
74 |
+
color_guidance_weight: float,
|
75 |
+
rich_text_input: str,
|
76 |
+
height: int,
|
77 |
+
width: int,
|
78 |
+
steps: int,
|
79 |
+
guidance_weight: float,
|
80 |
+
):
|
81 |
+
run_dir = 'results/'
|
82 |
+
os.makedirs(run_dir, exist_ok=True)
|
83 |
+
# Load region diffusion model.
|
84 |
+
height = int(height) if height else 1024
|
85 |
+
width = int(width) if width else 1024
|
86 |
+
steps = 41 if not steps else steps
|
87 |
+
guidance_weight = 8.5 if not guidance_weight else guidance_weight
|
88 |
+
text_input = rich_text_input if rich_text_input != '' and rich_text_input != None else text_input
|
89 |
+
print('text_input', text_input, width, height, steps, guidance_weight, num_segments, segment_threshold, inject_interval, inject_background, color_guidance_weight, negative_text)
|
90 |
+
if (text_input == '' or rich_text_input == ''):
|
91 |
+
raise gr.Error("Please enter some text.")
|
92 |
+
# parse json to span attributes
|
93 |
+
base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
94 |
+
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
|
95 |
+
json.loads(text_input))
|
96 |
+
|
97 |
+
# create control input for region diffusion
|
98 |
+
region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
|
99 |
+
model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
100 |
+
footnote_target_tokens, color_text_prompts, color_names)
|
101 |
+
|
102 |
+
# create control input for cross attention
|
103 |
+
text_format_dict = get_attention_control_input(
|
104 |
+
model, base_tokens, size_text_prompts_and_sizes)
|
105 |
+
|
106 |
+
# create control input for region guidance
|
107 |
+
text_format_dict, color_target_token_ids = get_gradient_guidance_input(
|
108 |
+
model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
|
109 |
+
|
110 |
+
seed_everything(seed)
|
111 |
+
|
112 |
+
# get token maps from plain text to image generation.
|
113 |
+
begin_time = time.time()
|
114 |
+
if model.selfattn_maps is None and model.crossattn_maps is None:
|
115 |
+
model.remove_tokenmap_hooks()
|
116 |
+
model.register_tokenmap_hooks()
|
117 |
+
else:
|
118 |
+
model.remove_tokenmap_hooks()
|
119 |
+
model.remove_tokenmap_hooks()
|
120 |
+
plain_img = model.sample([base_text_prompt], negative_prompt=[negative_text],
|
121 |
+
height=height, width=width, num_inference_steps=steps,
|
122 |
+
guidance_scale=guidance_weight, run_rich_text=False)
|
123 |
+
print('time lapses to get attention maps: %.4f' %
|
124 |
+
(time.time()-begin_time))
|
125 |
+
seed_everything(seed)
|
126 |
+
color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
127 |
+
1024//8, 1024//8, color_target_token_ids[:-1], seed,
|
128 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
129 |
+
return_vis=True)
|
130 |
+
seed_everything(seed)
|
131 |
+
model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
132 |
+
1024//8, 1024//8, region_target_token_ids[:-1], seed,
|
133 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
134 |
+
return_vis=True)
|
135 |
+
color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
|
136 |
+
for obj_mask in color_obj_masks[:-1]:
|
137 |
+
color_obj_atten_all += obj_mask
|
138 |
+
color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
|
139 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
140 |
+
antialias=True)
|
141 |
+
for color_obj_mask in color_obj_masks]
|
142 |
+
text_format_dict['color_obj_atten'] = color_obj_masks
|
143 |
+
text_format_dict['color_obj_atten_all'] = color_obj_atten_all
|
144 |
+
model.remove_tokenmap_hooks()
|
145 |
+
|
146 |
+
# generate image from rich text
|
147 |
+
begin_time = time.time()
|
148 |
+
seed_everything(seed)
|
149 |
+
rich_img = model.sample(region_text_prompts, negative_prompt=[negative_text],
|
150 |
+
height=height, width=width, num_inference_steps=steps,
|
151 |
+
guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
|
152 |
+
text_format_dict=text_format_dict, inject_selfattn=inject_interval,
|
153 |
+
inject_background=inject_background, run_rich_text=True)
|
154 |
+
print('time lapses to generate image from rich text: %.4f' %
|
155 |
+
(time.time()-begin_time))
|
156 |
+
return [plain_img.images[0], rich_img.images[0], segments_vis, token_maps]
|
157 |
+
|
158 |
+
with gr.Blocks(css=css) as demo:
|
159 |
+
url_params = gr.JSON({}, visible=False, label="URL Params")
|
160 |
+
gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
|
161 |
+
<p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
|
162 |
+
<p> UMD, Adobe, CMU <p/>
|
163 |
+
<p> ICCV, 2023 <p/>
|
164 |
+
<p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
|
165 |
+
<p> Our method is now using Stable Diffusion XL. For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
|
166 |
+
with gr.Row():
|
167 |
+
with gr.Column():
|
168 |
+
rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
|
169 |
+
rich_text_input = gr.Textbox(value="", visible=False)
|
170 |
+
text_input = gr.Textbox(
|
171 |
+
label='Rich-text JSON Input',
|
172 |
+
visible=False,
|
173 |
+
max_lines=1,
|
174 |
+
placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
|
175 |
+
elem_id="text_input"
|
176 |
+
)
|
177 |
+
negative_prompt = gr.Textbox(
|
178 |
+
label='Negative Prompt',
|
179 |
+
max_lines=1,
|
180 |
+
placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
|
181 |
+
elem_id="negative_prompt"
|
182 |
+
)
|
183 |
+
segment_threshold = gr.Slider(label='Token map threshold',
|
184 |
+
info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
|
185 |
+
minimum=0,
|
186 |
+
maximum=1,
|
187 |
+
step=0.01,
|
188 |
+
value=0.25)
|
189 |
+
inject_interval = gr.Slider(label='Detail preservation',
|
190 |
+
info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
|
191 |
+
minimum=0,
|
192 |
+
maximum=1,
|
193 |
+
step=0.01,
|
194 |
+
value=0.)
|
195 |
+
inject_background = gr.Slider(label='Unformatted token preservation',
|
196 |
+
info='(To affect less the tokens without any rich-text attributes, increase this.)',
|
197 |
+
minimum=0,
|
198 |
+
maximum=1,
|
199 |
+
step=0.01,
|
200 |
+
value=0.3)
|
201 |
+
color_guidance_weight = gr.Slider(label='Color weight',
|
202 |
+
info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
|
203 |
+
minimum=0,
|
204 |
+
maximum=2,
|
205 |
+
step=0.1,
|
206 |
+
value=0.5)
|
207 |
+
num_segments = gr.Slider(label='Number of segments',
|
208 |
+
minimum=2,
|
209 |
+
maximum=20,
|
210 |
+
step=1,
|
211 |
+
value=9)
|
212 |
+
seed = gr.Slider(label='Seed',
|
213 |
+
minimum=0,
|
214 |
+
maximum=100000,
|
215 |
+
step=1,
|
216 |
+
value=6,
|
217 |
+
elem_id="seed"
|
218 |
+
)
|
219 |
+
with gr.Accordion('Other Parameters', open=False):
|
220 |
+
steps = gr.Slider(label='Number of Steps',
|
221 |
+
minimum=0,
|
222 |
+
maximum=500,
|
223 |
+
step=1,
|
224 |
+
value=41)
|
225 |
+
guidance_weight = gr.Slider(label='CFG weight',
|
226 |
+
minimum=0,
|
227 |
+
maximum=50,
|
228 |
+
step=0.1,
|
229 |
+
value=8.5)
|
230 |
+
width = gr.Dropdown(choices=[1024],
|
231 |
+
value=1024,
|
232 |
+
label='Width',
|
233 |
+
visible=True)
|
234 |
+
height = gr.Dropdown(choices=[1024],
|
235 |
+
value=1024,
|
236 |
+
label='height',
|
237 |
+
visible=True)
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
with gr.Column(scale=1, min_width=100):
|
241 |
+
generate_button = gr.Button("Generate")
|
242 |
+
load_params_button = gr.Button(
|
243 |
+
"Load from URL Params", visible=True)
|
244 |
+
with gr.Column():
|
245 |
+
richtext_result = gr.Image(
|
246 |
+
label='Rich-text', elem_id="rich-text-image")
|
247 |
+
richtext_result.style(height=784)
|
248 |
+
with gr.Row():
|
249 |
+
plaintext_result = gr.Image(
|
250 |
+
label='Plain-text', elem_id="plain-text-image")
|
251 |
+
segments = gr.Image(label='Segmentation')
|
252 |
+
with gr.Row():
|
253 |
+
token_map = gr.Image(label='Token Maps')
|
254 |
+
with gr.Row(visible=False) as share_row:
|
255 |
+
with gr.Group(elem_id="share-btn-container"):
|
256 |
+
community_icon = gr.HTML(community_icon_html)
|
257 |
+
loading_icon = gr.HTML(loading_icon_html)
|
258 |
+
share_button = gr.Button(
|
259 |
+
"Share to community", elem_id="share-btn")
|
260 |
+
share_button.click(None, [], [], _js=share_js)
|
261 |
+
# with gr.Row():
|
262 |
+
# gr.Markdown(help_text)
|
263 |
+
|
264 |
+
with gr.Row():
|
265 |
+
footnote_examples = [
|
266 |
+
[
|
267 |
+
'{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
|
268 |
+
'',
|
269 |
+
9,
|
270 |
+
0.3,
|
271 |
+
0.3,
|
272 |
+
0.5,
|
273 |
+
3,
|
274 |
+
0,
|
275 |
+
None,
|
276 |
+
],
|
277 |
+
[
|
278 |
+
'{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard."}]}',
|
279 |
+
'',
|
280 |
+
12,
|
281 |
+
0.4,
|
282 |
+
0.3,
|
283 |
+
0.5,
|
284 |
+
3,
|
285 |
+
0,
|
286 |
+
None,
|
287 |
+
],
|
288 |
+
[
|
289 |
+
'{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
|
290 |
+
'',
|
291 |
+
5,
|
292 |
+
0.3,
|
293 |
+
0,
|
294 |
+
0.1,
|
295 |
+
4,
|
296 |
+
0,
|
297 |
+
None,
|
298 |
+
],
|
299 |
+
]
|
300 |
+
|
301 |
+
gr.Examples(examples=footnote_examples,
|
302 |
+
label='Footnote examples',
|
303 |
+
inputs=[
|
304 |
+
text_input,
|
305 |
+
negative_prompt,
|
306 |
+
num_segments,
|
307 |
+
segment_threshold,
|
308 |
+
inject_interval,
|
309 |
+
inject_background,
|
310 |
+
seed,
|
311 |
+
color_guidance_weight,
|
312 |
+
rich_text_input,
|
313 |
+
],
|
314 |
+
outputs=[
|
315 |
+
plaintext_result,
|
316 |
+
richtext_result,
|
317 |
+
segments,
|
318 |
+
token_map,
|
319 |
+
],
|
320 |
+
fn=generate,
|
321 |
+
cache_examples=True,
|
322 |
+
examples_per_page=20)
|
323 |
+
with gr.Row():
|
324 |
+
color_examples = [
|
325 |
+
[
|
326 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
327 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
328 |
+
11,
|
329 |
+
0.5,
|
330 |
+
0.3,
|
331 |
+
0.3,
|
332 |
+
6,
|
333 |
+
0.5,
|
334 |
+
None,
|
335 |
+
],
|
336 |
+
[
|
337 |
+
'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
|
338 |
+
'',
|
339 |
+
10,
|
340 |
+
0.5,
|
341 |
+
0.5,
|
342 |
+
0.3,
|
343 |
+
7,
|
344 |
+
0.5,
|
345 |
+
None,
|
346 |
+
],
|
347 |
+
]
|
348 |
+
gr.Examples(examples=color_examples,
|
349 |
+
label='Font color examples',
|
350 |
+
inputs=[
|
351 |
+
text_input,
|
352 |
+
negative_prompt,
|
353 |
+
num_segments,
|
354 |
+
segment_threshold,
|
355 |
+
inject_interval,
|
356 |
+
inject_background,
|
357 |
+
seed,
|
358 |
+
color_guidance_weight,
|
359 |
+
rich_text_input,
|
360 |
+
],
|
361 |
+
outputs=[
|
362 |
+
plaintext_result,
|
363 |
+
richtext_result,
|
364 |
+
segments,
|
365 |
+
token_map,
|
366 |
+
],
|
367 |
+
fn=generate,
|
368 |
+
cache_examples=True,
|
369 |
+
examples_per_page=20)
|
370 |
+
|
371 |
+
with gr.Row():
|
372 |
+
style_examples = [
|
373 |
+
[
|
374 |
+
'{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
|
375 |
+
'',
|
376 |
+
10,
|
377 |
+
0.6,
|
378 |
+
0,
|
379 |
+
0.4,
|
380 |
+
5,
|
381 |
+
0,
|
382 |
+
None,
|
383 |
+
],
|
384 |
+
[
|
385 |
+
'{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
|
386 |
+
'',
|
387 |
+
2,
|
388 |
+
0.6,
|
389 |
+
0,
|
390 |
+
0,
|
391 |
+
6,
|
392 |
+
0.5,
|
393 |
+
None,
|
394 |
+
],
|
395 |
+
]
|
396 |
+
gr.Examples(examples=style_examples,
|
397 |
+
label='Font style examples',
|
398 |
+
inputs=[
|
399 |
+
text_input,
|
400 |
+
negative_prompt,
|
401 |
+
num_segments,
|
402 |
+
segment_threshold,
|
403 |
+
inject_interval,
|
404 |
+
inject_background,
|
405 |
+
seed,
|
406 |
+
color_guidance_weight,
|
407 |
+
rich_text_input,
|
408 |
+
],
|
409 |
+
outputs=[
|
410 |
+
plaintext_result,
|
411 |
+
richtext_result,
|
412 |
+
segments,
|
413 |
+
token_map,
|
414 |
+
],
|
415 |
+
fn=generate,
|
416 |
+
cache_examples=True,
|
417 |
+
examples_per_page=20)
|
418 |
+
|
419 |
+
with gr.Row():
|
420 |
+
size_examples = [
|
421 |
+
[
|
422 |
+
'{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
|
423 |
+
'',
|
424 |
+
5,
|
425 |
+
0.3,
|
426 |
+
0,
|
427 |
+
0,
|
428 |
+
3,
|
429 |
+
1,
|
430 |
+
None,
|
431 |
+
],
|
432 |
+
[
|
433 |
+
'{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
|
434 |
+
'',
|
435 |
+
5,
|
436 |
+
0.3,
|
437 |
+
0,
|
438 |
+
0,
|
439 |
+
3,
|
440 |
+
1,
|
441 |
+
None,
|
442 |
+
],
|
443 |
+
[
|
444 |
+
'{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
|
445 |
+
'',
|
446 |
+
5,
|
447 |
+
0.3,
|
448 |
+
0,
|
449 |
+
0,
|
450 |
+
3,
|
451 |
+
1,
|
452 |
+
None,
|
453 |
+
],
|
454 |
+
]
|
455 |
+
gr.Examples(examples=size_examples,
|
456 |
+
label='Font size examples',
|
457 |
+
inputs=[
|
458 |
+
text_input,
|
459 |
+
negative_prompt,
|
460 |
+
num_segments,
|
461 |
+
segment_threshold,
|
462 |
+
inject_interval,
|
463 |
+
inject_background,
|
464 |
+
seed,
|
465 |
+
color_guidance_weight,
|
466 |
+
rich_text_input,
|
467 |
+
],
|
468 |
+
outputs=[
|
469 |
+
plaintext_result,
|
470 |
+
richtext_result,
|
471 |
+
segments,
|
472 |
+
token_map,
|
473 |
+
],
|
474 |
+
fn=generate,
|
475 |
+
cache_examples=True,
|
476 |
+
examples_per_page=20)
|
477 |
+
generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
|
478 |
+
fn=generate,
|
479 |
+
inputs=[
|
480 |
+
text_input,
|
481 |
+
negative_prompt,
|
482 |
+
num_segments,
|
483 |
+
segment_threshold,
|
484 |
+
inject_interval,
|
485 |
+
inject_background,
|
486 |
+
seed,
|
487 |
+
color_guidance_weight,
|
488 |
+
rich_text_input,
|
489 |
+
height,
|
490 |
+
width,
|
491 |
+
steps,
|
492 |
+
guidance_weight,
|
493 |
+
],
|
494 |
+
outputs=[plaintext_result, richtext_result, segments, token_map],
|
495 |
+
_js=get_js_data
|
496 |
+
).then(
|
497 |
+
fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
|
498 |
+
text_input.change(
|
499 |
+
fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
|
500 |
+
# load url param prompt to textinput
|
501 |
+
load_params_button.click(fn=lambda x: x['prompt'], inputs=[
|
502 |
+
url_params], outputs=[text_input], queue=False)
|
503 |
+
demo.load(
|
504 |
+
fn=load_url_params,
|
505 |
+
inputs=[url_params],
|
506 |
+
outputs=[load_params_button, url_params],
|
507 |
+
_js=get_window_url_params
|
508 |
+
)
|
509 |
+
demo.queue(concurrency_count=1)
|
510 |
+
demo.launch(share=False)
|
511 |
+
|
512 |
+
|
513 |
+
if __name__ == "__main__":
|
514 |
+
main()
|
app_sd.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import argparse
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
from models.region_diffusion import RegionDiffusion
|
12 |
+
from utils.attention_utils import get_token_maps
|
13 |
+
from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
|
14 |
+
get_attention_control_input, get_gradient_guidance_input
|
15 |
+
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from PIL import Image, ImageOps
|
19 |
+
from share_btn import community_icon_html, loading_icon_html, share_js, css
|
20 |
+
|
21 |
+
|
22 |
+
help_text = """
|
23 |
+
If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
|
24 |
+
1. If you format only a portion of a word rather than the complete word, an error may occur.
|
25 |
+
2. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
|
26 |
+
3. Consider using a different seed.
|
27 |
+
"""
|
28 |
+
|
29 |
+
|
30 |
+
canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
|
31 |
+
get_js_data = """
|
32 |
+
async (text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, rich_text_input, background_aug) => {
|
33 |
+
const richEl = document.getElementById("rich-text-root");
|
34 |
+
const data = richEl? richEl.contentDocument.body._data : {};
|
35 |
+
return [text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, JSON.stringify(data), background_aug];
|
36 |
+
}
|
37 |
+
"""
|
38 |
+
set_js_data = """
|
39 |
+
async (text_input) => {
|
40 |
+
const richEl = document.getElementById("rich-text-root");
|
41 |
+
const data = text_input ? JSON.parse(text_input) : null;
|
42 |
+
if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
|
43 |
+
}
|
44 |
+
"""
|
45 |
+
|
46 |
+
get_window_url_params = """
|
47 |
+
async (url_params) => {
|
48 |
+
const params = new URLSearchParams(window.location.search);
|
49 |
+
url_params = Object.fromEntries(params);
|
50 |
+
return [url_params];
|
51 |
+
}
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
def load_url_params(url_params):
|
56 |
+
if 'prompt' in url_params:
|
57 |
+
return gr.update(visible=True), url_params
|
58 |
+
else:
|
59 |
+
return gr.update(visible=False), url_params
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
64 |
+
model = RegionDiffusion(device)
|
65 |
+
|
66 |
+
def generate(
|
67 |
+
text_input: str,
|
68 |
+
negative_text: str,
|
69 |
+
height: int,
|
70 |
+
width: int,
|
71 |
+
seed: int,
|
72 |
+
steps: int,
|
73 |
+
num_segments: int,
|
74 |
+
segment_threshold: float,
|
75 |
+
inject_interval: float,
|
76 |
+
guidance_weight: float,
|
77 |
+
color_guidance_weight: float,
|
78 |
+
rich_text_input: str,
|
79 |
+
background_aug: bool,
|
80 |
+
):
|
81 |
+
run_dir = 'results/'
|
82 |
+
os.makedirs(run_dir, exist_ok=True)
|
83 |
+
# Load region diffusion model.
|
84 |
+
height = int(height)
|
85 |
+
width = int(width)
|
86 |
+
steps = 41 if not steps else steps
|
87 |
+
guidance_weight = 8.5 if not guidance_weight else guidance_weight
|
88 |
+
text_input = rich_text_input if rich_text_input != '' else text_input
|
89 |
+
print('text_input', text_input)
|
90 |
+
if (text_input == '' or rich_text_input == ''):
|
91 |
+
raise gr.Error("Please enter some text.")
|
92 |
+
# parse json to span attributes
|
93 |
+
base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
94 |
+
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
|
95 |
+
json.loads(text_input))
|
96 |
+
|
97 |
+
# create control input for region diffusion
|
98 |
+
region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
|
99 |
+
model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
100 |
+
footnote_target_tokens, color_text_prompts, color_names)
|
101 |
+
|
102 |
+
# create control input for cross attention
|
103 |
+
text_format_dict = get_attention_control_input(
|
104 |
+
model, base_tokens, size_text_prompts_and_sizes)
|
105 |
+
|
106 |
+
# create control input for region guidance
|
107 |
+
text_format_dict, color_target_token_ids = get_gradient_guidance_input(
|
108 |
+
model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
|
109 |
+
|
110 |
+
seed_everything(seed)
|
111 |
+
|
112 |
+
# get token maps from plain text to image generation.
|
113 |
+
begin_time = time.time()
|
114 |
+
if model.selfattn_maps is None and model.crossattn_maps is None:
|
115 |
+
model.remove_tokenmap_hooks()
|
116 |
+
model.register_tokenmap_hooks()
|
117 |
+
else:
|
118 |
+
model.reset_attention_maps()
|
119 |
+
model.remove_tokenmap_hooks()
|
120 |
+
plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
|
121 |
+
height=height, width=width, num_inference_steps=steps,
|
122 |
+
guidance_scale=guidance_weight)
|
123 |
+
print('time lapses to get attention maps: %.4f' %
|
124 |
+
(time.time()-begin_time))
|
125 |
+
seed_everything(seed)
|
126 |
+
color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
127 |
+
512//8, 512//8, color_target_token_ids[:-1], seed,
|
128 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
129 |
+
return_vis=True)
|
130 |
+
seed_everything(seed)
|
131 |
+
model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
132 |
+
512//8, 512//8, region_target_token_ids[:-1], seed,
|
133 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
134 |
+
return_vis=True)
|
135 |
+
color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
|
136 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
137 |
+
antialias=True)
|
138 |
+
for color_obj_mask in color_obj_masks]
|
139 |
+
text_format_dict['color_obj_atten'] = color_obj_masks
|
140 |
+
model.remove_tokenmap_hooks()
|
141 |
+
|
142 |
+
# generate image from rich text
|
143 |
+
begin_time = time.time()
|
144 |
+
seed_everything(seed)
|
145 |
+
if background_aug:
|
146 |
+
bg_aug_end = 500
|
147 |
+
else:
|
148 |
+
bg_aug_end = 1000
|
149 |
+
rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
|
150 |
+
height=height, width=width, num_inference_steps=steps,
|
151 |
+
guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
|
152 |
+
text_format_dict=text_format_dict, inject_selfattn=inject_interval,
|
153 |
+
bg_aug_end=bg_aug_end)
|
154 |
+
print('time lapses to generate image from rich text: %.4f' %
|
155 |
+
(time.time()-begin_time))
|
156 |
+
return [plain_img[0], rich_img[0], segments_vis, token_maps]
|
157 |
+
|
158 |
+
with gr.Blocks(css=css) as demo:
|
159 |
+
url_params = gr.JSON({}, visible=False, label="URL Params")
|
160 |
+
gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
|
161 |
+
<p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
|
162 |
+
<p> UMD, Adobe, CMU <p/>
|
163 |
+
<p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
|
164 |
+
<p> For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
|
165 |
+
with gr.Row():
|
166 |
+
with gr.Column():
|
167 |
+
rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
|
168 |
+
rich_text_input = gr.Textbox(value="", visible=False)
|
169 |
+
text_input = gr.Textbox(
|
170 |
+
label='Rich-text JSON Input',
|
171 |
+
visible=False,
|
172 |
+
max_lines=1,
|
173 |
+
placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
|
174 |
+
elem_id="text_input"
|
175 |
+
)
|
176 |
+
negative_prompt = gr.Textbox(
|
177 |
+
label='Negative Prompt',
|
178 |
+
max_lines=1,
|
179 |
+
placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
|
180 |
+
elem_id="negative_prompt"
|
181 |
+
)
|
182 |
+
segment_threshold = gr.Slider(label='Token map threshold',
|
183 |
+
info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
|
184 |
+
minimum=0,
|
185 |
+
maximum=1,
|
186 |
+
step=0.01,
|
187 |
+
value=0.25)
|
188 |
+
inject_interval = gr.Slider(label='Detail preservation',
|
189 |
+
info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
|
190 |
+
minimum=0,
|
191 |
+
maximum=1,
|
192 |
+
step=0.01,
|
193 |
+
value=0.)
|
194 |
+
color_guidance_weight = gr.Slider(label='Color weight',
|
195 |
+
info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
|
196 |
+
minimum=0,
|
197 |
+
maximum=2,
|
198 |
+
step=0.1,
|
199 |
+
value=0.5)
|
200 |
+
num_segments = gr.Slider(label='Number of segments',
|
201 |
+
minimum=2,
|
202 |
+
maximum=20,
|
203 |
+
step=1,
|
204 |
+
value=9)
|
205 |
+
seed = gr.Slider(label='Seed',
|
206 |
+
minimum=0,
|
207 |
+
maximum=100000,
|
208 |
+
step=1,
|
209 |
+
value=6,
|
210 |
+
elem_id="seed"
|
211 |
+
)
|
212 |
+
background_aug = gr.Checkbox(
|
213 |
+
label='Precise region alignment',
|
214 |
+
info='(For strict region alignment, select this option, but beware of potential artifacts when using with style.)',
|
215 |
+
value=True)
|
216 |
+
with gr.Accordion('Other Parameters', open=False):
|
217 |
+
steps = gr.Slider(label='Number of Steps',
|
218 |
+
minimum=0,
|
219 |
+
maximum=500,
|
220 |
+
step=1,
|
221 |
+
value=41)
|
222 |
+
guidance_weight = gr.Slider(label='CFG weight',
|
223 |
+
minimum=0,
|
224 |
+
maximum=50,
|
225 |
+
step=0.1,
|
226 |
+
value=8.5)
|
227 |
+
width = gr.Dropdown(choices=[512],
|
228 |
+
value=512,
|
229 |
+
label='Width',
|
230 |
+
visible=True)
|
231 |
+
height = gr.Dropdown(choices=[512],
|
232 |
+
value=512,
|
233 |
+
label='height',
|
234 |
+
visible=True)
|
235 |
+
|
236 |
+
with gr.Row():
|
237 |
+
with gr.Column(scale=1, min_width=100):
|
238 |
+
generate_button = gr.Button("Generate")
|
239 |
+
load_params_button = gr.Button(
|
240 |
+
"Load from URL Params", visible=True)
|
241 |
+
with gr.Column():
|
242 |
+
richtext_result = gr.Image(
|
243 |
+
label='Rich-text', elem_id="rich-text-image")
|
244 |
+
richtext_result.style(height=512)
|
245 |
+
with gr.Row():
|
246 |
+
plaintext_result = gr.Image(
|
247 |
+
label='Plain-text', elem_id="plain-text-image")
|
248 |
+
segments = gr.Image(label='Segmentation')
|
249 |
+
with gr.Row():
|
250 |
+
token_map = gr.Image(label='Token Maps')
|
251 |
+
with gr.Row(visible=False) as share_row:
|
252 |
+
with gr.Group(elem_id="share-btn-container"):
|
253 |
+
community_icon = gr.HTML(community_icon_html)
|
254 |
+
loading_icon = gr.HTML(loading_icon_html)
|
255 |
+
share_button = gr.Button(
|
256 |
+
"Share to community", elem_id="share-btn")
|
257 |
+
share_button.click(None, [], [], _js=share_js)
|
258 |
+
with gr.Row():
|
259 |
+
gr.Markdown(help_text)
|
260 |
+
|
261 |
+
with gr.Row():
|
262 |
+
footnote_examples = [
|
263 |
+
[
|
264 |
+
'{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
|
265 |
+
'',
|
266 |
+
5,
|
267 |
+
0.3,
|
268 |
+
0,
|
269 |
+
6,
|
270 |
+
1,
|
271 |
+
None,
|
272 |
+
True
|
273 |
+
],
|
274 |
+
[
|
275 |
+
'{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a stove with gas burners and a built-in oven "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, Ζ 1.8, 35 mm, 8k, medium - format print"}]}',
|
276 |
+
'',
|
277 |
+
6,
|
278 |
+
0.5,
|
279 |
+
0,
|
280 |
+
6,
|
281 |
+
1,
|
282 |
+
None,
|
283 |
+
True
|
284 |
+
],
|
285 |
+
[
|
286 |
+
'{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
|
287 |
+
'',
|
288 |
+
4,
|
289 |
+
0.3,
|
290 |
+
0,
|
291 |
+
4,
|
292 |
+
1,
|
293 |
+
None,
|
294 |
+
True
|
295 |
+
],
|
296 |
+
]
|
297 |
+
|
298 |
+
gr.Examples(examples=footnote_examples,
|
299 |
+
label='Footnote examples',
|
300 |
+
inputs=[
|
301 |
+
text_input,
|
302 |
+
negative_prompt,
|
303 |
+
num_segments,
|
304 |
+
segment_threshold,
|
305 |
+
inject_interval,
|
306 |
+
seed,
|
307 |
+
color_guidance_weight,
|
308 |
+
rich_text_input,
|
309 |
+
background_aug,
|
310 |
+
],
|
311 |
+
outputs=[
|
312 |
+
plaintext_result,
|
313 |
+
richtext_result,
|
314 |
+
segments,
|
315 |
+
token_map,
|
316 |
+
],
|
317 |
+
fn=generate,
|
318 |
+
# cache_examples=True,
|
319 |
+
examples_per_page=20)
|
320 |
+
with gr.Row():
|
321 |
+
color_examples = [
|
322 |
+
[
|
323 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#00ffff"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
324 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
325 |
+
9,
|
326 |
+
0.25,
|
327 |
+
0.3,
|
328 |
+
6,
|
329 |
+
0.5,
|
330 |
+
None,
|
331 |
+
True
|
332 |
+
],
|
333 |
+
[
|
334 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#eeeeee"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
335 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
336 |
+
9,
|
337 |
+
0.25,
|
338 |
+
0.3,
|
339 |
+
6,
|
340 |
+
0.1,
|
341 |
+
None,
|
342 |
+
True
|
343 |
+
],
|
344 |
+
[
|
345 |
+
'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
|
346 |
+
'',
|
347 |
+
5,
|
348 |
+
0.3,
|
349 |
+
0.5,
|
350 |
+
6,
|
351 |
+
0.5,
|
352 |
+
None,
|
353 |
+
False
|
354 |
+
],
|
355 |
+
[
|
356 |
+
'{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
|
357 |
+
'',
|
358 |
+
3,
|
359 |
+
0.3,
|
360 |
+
0,
|
361 |
+
9,
|
362 |
+
1,
|
363 |
+
None,
|
364 |
+
False
|
365 |
+
],
|
366 |
+
[
|
367 |
+
'{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
|
368 |
+
'',
|
369 |
+
5,
|
370 |
+
0.3,
|
371 |
+
0,
|
372 |
+
5,
|
373 |
+
0.6,
|
374 |
+
None,
|
375 |
+
False
|
376 |
+
],
|
377 |
+
]
|
378 |
+
gr.Examples(examples=color_examples,
|
379 |
+
label='Font color examples',
|
380 |
+
inputs=[
|
381 |
+
text_input,
|
382 |
+
negative_prompt,
|
383 |
+
num_segments,
|
384 |
+
segment_threshold,
|
385 |
+
inject_interval,
|
386 |
+
seed,
|
387 |
+
color_guidance_weight,
|
388 |
+
rich_text_input,
|
389 |
+
background_aug,
|
390 |
+
],
|
391 |
+
outputs=[
|
392 |
+
plaintext_result,
|
393 |
+
richtext_result,
|
394 |
+
segments,
|
395 |
+
token_map,
|
396 |
+
],
|
397 |
+
fn=generate,
|
398 |
+
# cache_examples=True,
|
399 |
+
examples_per_page=20)
|
400 |
+
|
401 |
+
with gr.Row():
|
402 |
+
style_examples = [
|
403 |
+
[
|
404 |
+
'{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
|
405 |
+
'',
|
406 |
+
10,
|
407 |
+
0.45,
|
408 |
+
0,
|
409 |
+
0.2,
|
410 |
+
3,
|
411 |
+
0.5,
|
412 |
+
None,
|
413 |
+
False
|
414 |
+
],
|
415 |
+
[
|
416 |
+
'{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
|
417 |
+
'worst quality, dark, poor quality',
|
418 |
+
2,
|
419 |
+
0.45,
|
420 |
+
0,
|
421 |
+
9,
|
422 |
+
0.5,
|
423 |
+
None,
|
424 |
+
False
|
425 |
+
],
|
426 |
+
[
|
427 |
+
'{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
|
428 |
+
'',
|
429 |
+
2,
|
430 |
+
0.45,
|
431 |
+
0,
|
432 |
+
0,
|
433 |
+
6,
|
434 |
+
0.5,
|
435 |
+
None,
|
436 |
+
False
|
437 |
+
],
|
438 |
+
]
|
439 |
+
gr.Examples(examples=style_examples,
|
440 |
+
label='Font style examples',
|
441 |
+
inputs=[
|
442 |
+
text_input,
|
443 |
+
negative_prompt,
|
444 |
+
num_segments,
|
445 |
+
segment_threshold,
|
446 |
+
inject_interval,
|
447 |
+
seed,
|
448 |
+
color_guidance_weight,
|
449 |
+
rich_text_input,
|
450 |
+
background_aug,
|
451 |
+
],
|
452 |
+
outputs=[
|
453 |
+
plaintext_result,
|
454 |
+
richtext_result,
|
455 |
+
segments,
|
456 |
+
token_map,
|
457 |
+
],
|
458 |
+
fn=generate,
|
459 |
+
# cache_examples=True,
|
460 |
+
examples_per_page=20)
|
461 |
+
|
462 |
+
with gr.Row():
|
463 |
+
size_examples = [
|
464 |
+
[
|
465 |
+
'{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
|
466 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
467 |
+
5,
|
468 |
+
0.3,
|
469 |
+
0,
|
470 |
+
13,
|
471 |
+
1,
|
472 |
+
None,
|
473 |
+
False
|
474 |
+
],
|
475 |
+
[
|
476 |
+
'{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
|
477 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
478 |
+
5,
|
479 |
+
0.3,
|
480 |
+
0,
|
481 |
+
13,
|
482 |
+
1,
|
483 |
+
None,
|
484 |
+
False
|
485 |
+
],
|
486 |
+
[
|
487 |
+
'{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
|
488 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
489 |
+
5,
|
490 |
+
0.3,
|
491 |
+
0,
|
492 |
+
13,
|
493 |
+
1,
|
494 |
+
None,
|
495 |
+
False
|
496 |
+
],
|
497 |
+
]
|
498 |
+
gr.Examples(examples=size_examples,
|
499 |
+
label='Font size examples',
|
500 |
+
inputs=[
|
501 |
+
text_input,
|
502 |
+
negative_prompt,
|
503 |
+
num_segments,
|
504 |
+
segment_threshold,
|
505 |
+
inject_interval,
|
506 |
+
seed,
|
507 |
+
color_guidance_weight,
|
508 |
+
rich_text_input,
|
509 |
+
background_aug,
|
510 |
+
],
|
511 |
+
outputs=[
|
512 |
+
plaintext_result,
|
513 |
+
richtext_result,
|
514 |
+
segments,
|
515 |
+
token_map,
|
516 |
+
],
|
517 |
+
fn=generate,
|
518 |
+
# cache_examples=True,
|
519 |
+
examples_per_page=20)
|
520 |
+
generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
|
521 |
+
fn=generate,
|
522 |
+
inputs=[
|
523 |
+
text_input,
|
524 |
+
negative_prompt,
|
525 |
+
height,
|
526 |
+
width,
|
527 |
+
seed,
|
528 |
+
steps,
|
529 |
+
num_segments,
|
530 |
+
segment_threshold,
|
531 |
+
inject_interval,
|
532 |
+
guidance_weight,
|
533 |
+
color_guidance_weight,
|
534 |
+
rich_text_input,
|
535 |
+
background_aug
|
536 |
+
],
|
537 |
+
outputs=[plaintext_result, richtext_result, segments, token_map],
|
538 |
+
_js=get_js_data
|
539 |
+
).then(
|
540 |
+
fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
|
541 |
+
text_input.change(
|
542 |
+
fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
|
543 |
+
# load url param prompt to textinput
|
544 |
+
load_params_button.click(fn=lambda x: x['prompt'], inputs=[
|
545 |
+
url_params], outputs=[text_input], queue=False)
|
546 |
+
demo.load(
|
547 |
+
fn=load_url_params,
|
548 |
+
inputs=[url_params],
|
549 |
+
outputs=[load_params_button, url_params],
|
550 |
+
_js=get_window_url_params
|
551 |
+
)
|
552 |
+
demo.queue(concurrency_count=1)
|
553 |
+
demo.launch(share=False)
|
554 |
+
|
555 |
+
|
556 |
+
if __name__ == "__main__":
|
557 |
+
main()
|
models/attention.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import maybe_allow_in_graph
|
21 |
+
from diffusers.models.activations import get_activation
|
22 |
+
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
|
23 |
+
|
24 |
+
from models.attention_processor import Attention
|
25 |
+
|
26 |
+
@maybe_allow_in_graph
|
27 |
+
class BasicTransformerBlock(nn.Module):
|
28 |
+
r"""
|
29 |
+
A basic Transformer block.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
dim (`int`): The number of channels in the input and output.
|
33 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
34 |
+
attention_head_dim (`int`): The number of channels in each head.
|
35 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
36 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
37 |
+
only_cross_attention (`bool`, *optional*):
|
38 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
39 |
+
double_self_attention (`bool`, *optional*):
|
40 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
41 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
42 |
+
num_embeds_ada_norm (:
|
43 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
44 |
+
attention_bias (:
|
45 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
dim: int,
|
51 |
+
num_attention_heads: int,
|
52 |
+
attention_head_dim: int,
|
53 |
+
dropout=0.0,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
activation_fn: str = "geglu",
|
56 |
+
num_embeds_ada_norm: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
only_cross_attention: bool = False,
|
59 |
+
double_self_attention: bool = False,
|
60 |
+
upcast_attention: bool = False,
|
61 |
+
norm_elementwise_affine: bool = True,
|
62 |
+
norm_type: str = "layer_norm",
|
63 |
+
final_dropout: bool = False,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
self.only_cross_attention = only_cross_attention
|
67 |
+
|
68 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
69 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
70 |
+
|
71 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
72 |
+
raise ValueError(
|
73 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
74 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
75 |
+
)
|
76 |
+
|
77 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
78 |
+
# 1. Self-Attn
|
79 |
+
if self.use_ada_layer_norm:
|
80 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
81 |
+
elif self.use_ada_layer_norm_zero:
|
82 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
83 |
+
else:
|
84 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
85 |
+
self.attn1 = Attention(
|
86 |
+
query_dim=dim,
|
87 |
+
heads=num_attention_heads,
|
88 |
+
dim_head=attention_head_dim,
|
89 |
+
dropout=dropout,
|
90 |
+
bias=attention_bias,
|
91 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
92 |
+
upcast_attention=upcast_attention,
|
93 |
+
)
|
94 |
+
|
95 |
+
# 2. Cross-Attn
|
96 |
+
if cross_attention_dim is not None or double_self_attention:
|
97 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
98 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
99 |
+
# the second cross attention block.
|
100 |
+
self.norm2 = (
|
101 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
102 |
+
if self.use_ada_layer_norm
|
103 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
104 |
+
)
|
105 |
+
self.attn2 = Attention(
|
106 |
+
query_dim=dim,
|
107 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
108 |
+
heads=num_attention_heads,
|
109 |
+
dim_head=attention_head_dim,
|
110 |
+
dropout=dropout,
|
111 |
+
bias=attention_bias,
|
112 |
+
upcast_attention=upcast_attention,
|
113 |
+
) # is self-attn if encoder_hidden_states is none
|
114 |
+
else:
|
115 |
+
self.norm2 = None
|
116 |
+
self.attn2 = None
|
117 |
+
|
118 |
+
# 3. Feed-forward
|
119 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
120 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
121 |
+
|
122 |
+
# let chunk size default to None
|
123 |
+
self._chunk_size = None
|
124 |
+
self._chunk_dim = 0
|
125 |
+
|
126 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
127 |
+
# Sets chunk feed-forward
|
128 |
+
self._chunk_size = chunk_size
|
129 |
+
self._chunk_dim = dim
|
130 |
+
|
131 |
+
def forward(
|
132 |
+
self,
|
133 |
+
hidden_states: torch.FloatTensor,
|
134 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
135 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
136 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
137 |
+
timestep: Optional[torch.LongTensor] = None,
|
138 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
139 |
+
class_labels: Optional[torch.LongTensor] = None,
|
140 |
+
):
|
141 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
142 |
+
# 1. Self-Attention
|
143 |
+
if self.use_ada_layer_norm:
|
144 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
145 |
+
elif self.use_ada_layer_norm_zero:
|
146 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
147 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
norm_hidden_states = self.norm1(hidden_states)
|
151 |
+
|
152 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
153 |
+
|
154 |
+
# Rich-Text: ignore the attention probs
|
155 |
+
attn_output, _ = self.attn1(
|
156 |
+
norm_hidden_states,
|
157 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
158 |
+
attention_mask=attention_mask,
|
159 |
+
**cross_attention_kwargs,
|
160 |
+
)
|
161 |
+
if self.use_ada_layer_norm_zero:
|
162 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
163 |
+
hidden_states = attn_output + hidden_states
|
164 |
+
|
165 |
+
# 2. Cross-Attention
|
166 |
+
if self.attn2 is not None:
|
167 |
+
norm_hidden_states = (
|
168 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
169 |
+
)
|
170 |
+
|
171 |
+
# Rich-Text: ignore the attention probs
|
172 |
+
attn_output, _ = self.attn2(
|
173 |
+
norm_hidden_states,
|
174 |
+
encoder_hidden_states=encoder_hidden_states,
|
175 |
+
attention_mask=encoder_attention_mask,
|
176 |
+
**cross_attention_kwargs,
|
177 |
+
)
|
178 |
+
hidden_states = attn_output + hidden_states
|
179 |
+
|
180 |
+
# 3. Feed-forward
|
181 |
+
norm_hidden_states = self.norm3(hidden_states)
|
182 |
+
|
183 |
+
if self.use_ada_layer_norm_zero:
|
184 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
185 |
+
|
186 |
+
if self._chunk_size is not None:
|
187 |
+
# "feed_forward_chunk_size" can be used to save memory
|
188 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
189 |
+
raise ValueError(
|
190 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
191 |
+
)
|
192 |
+
|
193 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
194 |
+
ff_output = torch.cat(
|
195 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
196 |
+
dim=self._chunk_dim,
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
ff_output = self.ff(norm_hidden_states)
|
200 |
+
|
201 |
+
if self.use_ada_layer_norm_zero:
|
202 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
203 |
+
|
204 |
+
hidden_states = ff_output + hidden_states
|
205 |
+
|
206 |
+
return hidden_states
|
207 |
+
|
208 |
+
|
209 |
+
class FeedForward(nn.Module):
|
210 |
+
r"""
|
211 |
+
A feed-forward layer.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
dim (`int`): The number of channels in the input.
|
215 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
216 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
217 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
218 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
219 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
dim: int,
|
225 |
+
dim_out: Optional[int] = None,
|
226 |
+
mult: int = 4,
|
227 |
+
dropout: float = 0.0,
|
228 |
+
activation_fn: str = "geglu",
|
229 |
+
final_dropout: bool = False,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
inner_dim = int(dim * mult)
|
233 |
+
dim_out = dim_out if dim_out is not None else dim
|
234 |
+
|
235 |
+
if activation_fn == "gelu":
|
236 |
+
act_fn = GELU(dim, inner_dim)
|
237 |
+
if activation_fn == "gelu-approximate":
|
238 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
239 |
+
elif activation_fn == "geglu":
|
240 |
+
act_fn = GEGLU(dim, inner_dim)
|
241 |
+
elif activation_fn == "geglu-approximate":
|
242 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
243 |
+
|
244 |
+
self.net = nn.ModuleList([])
|
245 |
+
# project in
|
246 |
+
self.net.append(act_fn)
|
247 |
+
# project dropout
|
248 |
+
self.net.append(nn.Dropout(dropout))
|
249 |
+
# project out
|
250 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
251 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
252 |
+
if final_dropout:
|
253 |
+
self.net.append(nn.Dropout(dropout))
|
254 |
+
|
255 |
+
def forward(self, hidden_states):
|
256 |
+
for module in self.net:
|
257 |
+
hidden_states = module(hidden_states)
|
258 |
+
return hidden_states
|
259 |
+
|
260 |
+
|
261 |
+
class GELU(nn.Module):
|
262 |
+
r"""
|
263 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
267 |
+
super().__init__()
|
268 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
269 |
+
self.approximate = approximate
|
270 |
+
|
271 |
+
def gelu(self, gate):
|
272 |
+
if gate.device.type != "mps":
|
273 |
+
return F.gelu(gate, approximate=self.approximate)
|
274 |
+
# mps: gelu is not implemented for float16
|
275 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
276 |
+
|
277 |
+
def forward(self, hidden_states):
|
278 |
+
hidden_states = self.proj(hidden_states)
|
279 |
+
hidden_states = self.gelu(hidden_states)
|
280 |
+
return hidden_states
|
281 |
+
|
282 |
+
|
283 |
+
class GEGLU(nn.Module):
|
284 |
+
r"""
|
285 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
286 |
+
|
287 |
+
Parameters:
|
288 |
+
dim_in (`int`): The number of channels in the input.
|
289 |
+
dim_out (`int`): The number of channels in the output.
|
290 |
+
"""
|
291 |
+
|
292 |
+
def __init__(self, dim_in: int, dim_out: int):
|
293 |
+
super().__init__()
|
294 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
295 |
+
|
296 |
+
def gelu(self, gate):
|
297 |
+
if gate.device.type != "mps":
|
298 |
+
return F.gelu(gate)
|
299 |
+
# mps: gelu is not implemented for float16
|
300 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
301 |
+
|
302 |
+
def forward(self, hidden_states):
|
303 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
304 |
+
return hidden_states * self.gelu(gate)
|
305 |
+
|
306 |
+
|
307 |
+
class ApproximateGELU(nn.Module):
|
308 |
+
"""
|
309 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
310 |
+
|
311 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
312 |
+
"""
|
313 |
+
|
314 |
+
def __init__(self, dim_in: int, dim_out: int):
|
315 |
+
super().__init__()
|
316 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
317 |
+
|
318 |
+
def forward(self, x):
|
319 |
+
x = self.proj(x)
|
320 |
+
return x * torch.sigmoid(1.702 * x)
|
321 |
+
|
322 |
+
|
323 |
+
class AdaLayerNorm(nn.Module):
|
324 |
+
"""
|
325 |
+
Norm layer modified to incorporate timestep embeddings.
|
326 |
+
"""
|
327 |
+
|
328 |
+
def __init__(self, embedding_dim, num_embeddings):
|
329 |
+
super().__init__()
|
330 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
331 |
+
self.silu = nn.SiLU()
|
332 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
333 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
334 |
+
|
335 |
+
def forward(self, x, timestep):
|
336 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
337 |
+
scale, shift = torch.chunk(emb, 2)
|
338 |
+
x = self.norm(x) * (1 + scale) + shift
|
339 |
+
return x
|
340 |
+
|
341 |
+
|
342 |
+
class AdaLayerNormZero(nn.Module):
|
343 |
+
"""
|
344 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
345 |
+
"""
|
346 |
+
|
347 |
+
def __init__(self, embedding_dim, num_embeddings):
|
348 |
+
super().__init__()
|
349 |
+
|
350 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
351 |
+
|
352 |
+
self.silu = nn.SiLU()
|
353 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
354 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
355 |
+
|
356 |
+
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
357 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
358 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
359 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
360 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
361 |
+
|
362 |
+
|
363 |
+
class AdaGroupNorm(nn.Module):
|
364 |
+
"""
|
365 |
+
GroupNorm layer modified to incorporate timestep embeddings.
|
366 |
+
"""
|
367 |
+
|
368 |
+
def __init__(
|
369 |
+
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
|
370 |
+
):
|
371 |
+
super().__init__()
|
372 |
+
self.num_groups = num_groups
|
373 |
+
self.eps = eps
|
374 |
+
|
375 |
+
if act_fn is None:
|
376 |
+
self.act = None
|
377 |
+
else:
|
378 |
+
self.act = get_activation(act_fn)
|
379 |
+
|
380 |
+
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
381 |
+
|
382 |
+
def forward(self, x, emb):
|
383 |
+
if self.act:
|
384 |
+
emb = self.act(emb)
|
385 |
+
emb = self.linear(emb)
|
386 |
+
emb = emb[:, :, None, None]
|
387 |
+
scale, shift = emb.chunk(2, dim=1)
|
388 |
+
|
389 |
+
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
390 |
+
x = x * (1 + scale) + shift
|
391 |
+
return x
|
models/attention_processor.py
ADDED
@@ -0,0 +1,1687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Callable, Optional, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import deprecate, logging, maybe_allow_in_graph
|
21 |
+
from diffusers.utils.import_utils import is_xformers_available
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
25 |
+
|
26 |
+
|
27 |
+
if is_xformers_available():
|
28 |
+
import xformers
|
29 |
+
import xformers.ops
|
30 |
+
else:
|
31 |
+
xformers = None
|
32 |
+
|
33 |
+
|
34 |
+
@maybe_allow_in_graph
|
35 |
+
class Attention(nn.Module):
|
36 |
+
r"""
|
37 |
+
A cross attention layer.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
query_dim (`int`): The number of channels in the query.
|
41 |
+
cross_attention_dim (`int`, *optional*):
|
42 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
43 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
44 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
45 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
46 |
+
bias (`bool`, *optional*, defaults to False):
|
47 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
query_dim: int,
|
53 |
+
cross_attention_dim: Optional[int] = None,
|
54 |
+
heads: int = 8,
|
55 |
+
dim_head: int = 64,
|
56 |
+
dropout: float = 0.0,
|
57 |
+
bias=False,
|
58 |
+
upcast_attention: bool = False,
|
59 |
+
upcast_softmax: bool = False,
|
60 |
+
cross_attention_norm: Optional[str] = None,
|
61 |
+
cross_attention_norm_num_groups: int = 32,
|
62 |
+
added_kv_proj_dim: Optional[int] = None,
|
63 |
+
norm_num_groups: Optional[int] = None,
|
64 |
+
spatial_norm_dim: Optional[int] = None,
|
65 |
+
out_bias: bool = True,
|
66 |
+
scale_qk: bool = True,
|
67 |
+
only_cross_attention: bool = False,
|
68 |
+
eps: float = 1e-5,
|
69 |
+
rescale_output_factor: float = 1.0,
|
70 |
+
residual_connection: bool = False,
|
71 |
+
_from_deprecated_attn_block=False,
|
72 |
+
processor: Optional["AttnProcessor"] = None,
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
inner_dim = dim_head * heads
|
76 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
77 |
+
self.upcast_attention = upcast_attention
|
78 |
+
self.upcast_softmax = upcast_softmax
|
79 |
+
self.rescale_output_factor = rescale_output_factor
|
80 |
+
self.residual_connection = residual_connection
|
81 |
+
self.dropout = dropout
|
82 |
+
|
83 |
+
# we make use of this private variable to know whether this class is loaded
|
84 |
+
# with an deprecated state dict so that we can convert it on the fly
|
85 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
86 |
+
|
87 |
+
self.scale_qk = scale_qk
|
88 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
89 |
+
|
90 |
+
self.heads = heads
|
91 |
+
# for slice_size > 0 the attention score computation
|
92 |
+
# is split across the batch axis to save memory
|
93 |
+
# You can set slice_size with `set_attention_slice`
|
94 |
+
self.sliceable_head_dim = heads
|
95 |
+
|
96 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
97 |
+
self.only_cross_attention = only_cross_attention
|
98 |
+
|
99 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
100 |
+
raise ValueError(
|
101 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
102 |
+
)
|
103 |
+
|
104 |
+
if norm_num_groups is not None:
|
105 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
106 |
+
else:
|
107 |
+
self.group_norm = None
|
108 |
+
|
109 |
+
if spatial_norm_dim is not None:
|
110 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
111 |
+
else:
|
112 |
+
self.spatial_norm = None
|
113 |
+
|
114 |
+
if cross_attention_norm is None:
|
115 |
+
self.norm_cross = None
|
116 |
+
elif cross_attention_norm == "layer_norm":
|
117 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
118 |
+
elif cross_attention_norm == "group_norm":
|
119 |
+
if self.added_kv_proj_dim is not None:
|
120 |
+
# The given `encoder_hidden_states` are initially of shape
|
121 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
122 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
123 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
124 |
+
# the number of channels for the group norm.
|
125 |
+
norm_cross_num_channels = added_kv_proj_dim
|
126 |
+
else:
|
127 |
+
norm_cross_num_channels = cross_attention_dim
|
128 |
+
|
129 |
+
self.norm_cross = nn.GroupNorm(
|
130 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
raise ValueError(
|
134 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
135 |
+
)
|
136 |
+
|
137 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
138 |
+
|
139 |
+
if not self.only_cross_attention:
|
140 |
+
# only relevant for the `AddedKVProcessor` classes
|
141 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
142 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
143 |
+
else:
|
144 |
+
self.to_k = None
|
145 |
+
self.to_v = None
|
146 |
+
|
147 |
+
if self.added_kv_proj_dim is not None:
|
148 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
149 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
150 |
+
|
151 |
+
self.to_out = nn.ModuleList([])
|
152 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
153 |
+
self.to_out.append(nn.Dropout(dropout))
|
154 |
+
|
155 |
+
# set attention processor
|
156 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
157 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
158 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
159 |
+
if processor is None:
|
160 |
+
processor = (
|
161 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
162 |
+
)
|
163 |
+
self.set_processor(processor)
|
164 |
+
|
165 |
+
# Rich-Text: util function for averaging over attention heads
|
166 |
+
def reshape_batch_dim_to_heads_and_average(self, tensor):
|
167 |
+
batch_size, seq_len, seq_len2 = tensor.shape
|
168 |
+
head_size = self.heads
|
169 |
+
tensor = tensor.reshape(batch_size // head_size,
|
170 |
+
head_size, seq_len, seq_len2)
|
171 |
+
return tensor.mean(1)
|
172 |
+
|
173 |
+
def set_use_memory_efficient_attention_xformers(
|
174 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
175 |
+
):
|
176 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
177 |
+
self.processor,
|
178 |
+
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
179 |
+
)
|
180 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
181 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
182 |
+
)
|
183 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
184 |
+
self.processor,
|
185 |
+
(
|
186 |
+
AttnAddedKVProcessor,
|
187 |
+
AttnAddedKVProcessor2_0,
|
188 |
+
SlicedAttnAddedKVProcessor,
|
189 |
+
XFormersAttnAddedKVProcessor,
|
190 |
+
LoRAAttnAddedKVProcessor,
|
191 |
+
),
|
192 |
+
)
|
193 |
+
|
194 |
+
if use_memory_efficient_attention_xformers:
|
195 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
196 |
+
raise NotImplementedError(
|
197 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
198 |
+
)
|
199 |
+
if not is_xformers_available():
|
200 |
+
raise ModuleNotFoundError(
|
201 |
+
(
|
202 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
203 |
+
" xformers"
|
204 |
+
),
|
205 |
+
name="xformers",
|
206 |
+
)
|
207 |
+
elif not torch.cuda.is_available():
|
208 |
+
raise ValueError(
|
209 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
210 |
+
" only available for GPU "
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
try:
|
214 |
+
# Make sure we can run the memory efficient attention
|
215 |
+
_ = xformers.ops.memory_efficient_attention(
|
216 |
+
torch.randn((1, 2, 40), device="cuda"),
|
217 |
+
torch.randn((1, 2, 40), device="cuda"),
|
218 |
+
torch.randn((1, 2, 40), device="cuda"),
|
219 |
+
)
|
220 |
+
except Exception as e:
|
221 |
+
raise e
|
222 |
+
|
223 |
+
if is_lora:
|
224 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
225 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
226 |
+
processor = LoRAXFormersAttnProcessor(
|
227 |
+
hidden_size=self.processor.hidden_size,
|
228 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
229 |
+
rank=self.processor.rank,
|
230 |
+
attention_op=attention_op,
|
231 |
+
)
|
232 |
+
processor.load_state_dict(self.processor.state_dict())
|
233 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
234 |
+
elif is_custom_diffusion:
|
235 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
236 |
+
train_kv=self.processor.train_kv,
|
237 |
+
train_q_out=self.processor.train_q_out,
|
238 |
+
hidden_size=self.processor.hidden_size,
|
239 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
240 |
+
attention_op=attention_op,
|
241 |
+
)
|
242 |
+
processor.load_state_dict(self.processor.state_dict())
|
243 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
244 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
245 |
+
elif is_added_kv_processor:
|
246 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
247 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
248 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
249 |
+
# throw warning
|
250 |
+
logger.info(
|
251 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
252 |
+
)
|
253 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
254 |
+
else:
|
255 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
256 |
+
else:
|
257 |
+
if is_lora:
|
258 |
+
attn_processor_class = (
|
259 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
260 |
+
)
|
261 |
+
processor = attn_processor_class(
|
262 |
+
hidden_size=self.processor.hidden_size,
|
263 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
264 |
+
rank=self.processor.rank,
|
265 |
+
)
|
266 |
+
processor.load_state_dict(self.processor.state_dict())
|
267 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
268 |
+
elif is_custom_diffusion:
|
269 |
+
processor = CustomDiffusionAttnProcessor(
|
270 |
+
train_kv=self.processor.train_kv,
|
271 |
+
train_q_out=self.processor.train_q_out,
|
272 |
+
hidden_size=self.processor.hidden_size,
|
273 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
274 |
+
)
|
275 |
+
processor.load_state_dict(self.processor.state_dict())
|
276 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
277 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
278 |
+
else:
|
279 |
+
# set attention processor
|
280 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
281 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
282 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
283 |
+
processor = (
|
284 |
+
AttnProcessor2_0()
|
285 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
286 |
+
else AttnProcessor()
|
287 |
+
)
|
288 |
+
|
289 |
+
self.set_processor(processor)
|
290 |
+
|
291 |
+
def set_attention_slice(self, slice_size):
|
292 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
293 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
294 |
+
|
295 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
296 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
297 |
+
elif slice_size is not None:
|
298 |
+
processor = SlicedAttnProcessor(slice_size)
|
299 |
+
elif self.added_kv_proj_dim is not None:
|
300 |
+
processor = AttnAddedKVProcessor()
|
301 |
+
else:
|
302 |
+
# set attention processor
|
303 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
304 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
305 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
306 |
+
processor = (
|
307 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
308 |
+
)
|
309 |
+
|
310 |
+
self.set_processor(processor)
|
311 |
+
|
312 |
+
def set_processor(self, processor: "AttnProcessor"):
|
313 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
314 |
+
# pop `processor` from `self._modules`
|
315 |
+
if (
|
316 |
+
hasattr(self, "processor")
|
317 |
+
and isinstance(self.processor, torch.nn.Module)
|
318 |
+
and not isinstance(processor, torch.nn.Module)
|
319 |
+
):
|
320 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
321 |
+
self._modules.pop("processor")
|
322 |
+
|
323 |
+
self.processor = processor
|
324 |
+
|
325 |
+
# Rich-Text: inject self-attention maps
|
326 |
+
def forward(self, hidden_states, real_attn_probs=None, attn_weights=None, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
327 |
+
# The `Attention` class can call different attention processors / attention functions
|
328 |
+
# here we simply pass along all tensors to the selected processor class
|
329 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
330 |
+
return self.processor(
|
331 |
+
self,
|
332 |
+
hidden_states,
|
333 |
+
real_attn_probs=real_attn_probs,
|
334 |
+
attn_weights=attn_weights,
|
335 |
+
encoder_hidden_states=encoder_hidden_states,
|
336 |
+
attention_mask=attention_mask,
|
337 |
+
**cross_attention_kwargs,
|
338 |
+
)
|
339 |
+
|
340 |
+
def batch_to_head_dim(self, tensor):
|
341 |
+
head_size = self.heads
|
342 |
+
batch_size, seq_len, dim = tensor.shape
|
343 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
344 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
345 |
+
return tensor
|
346 |
+
|
347 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
348 |
+
head_size = self.heads
|
349 |
+
batch_size, seq_len, dim = tensor.shape
|
350 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
351 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
352 |
+
|
353 |
+
if out_dim == 3:
|
354 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
355 |
+
|
356 |
+
return tensor
|
357 |
+
|
358 |
+
# Rich-Text: return attention scores
|
359 |
+
def get_attention_scores(self, query, key, attention_mask=None, attn_weights=False):
|
360 |
+
dtype = query.dtype
|
361 |
+
if self.upcast_attention:
|
362 |
+
query = query.float()
|
363 |
+
key = key.float()
|
364 |
+
|
365 |
+
if attention_mask is None:
|
366 |
+
baddbmm_input = torch.empty(
|
367 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
368 |
+
)
|
369 |
+
beta = 0
|
370 |
+
else:
|
371 |
+
baddbmm_input = attention_mask
|
372 |
+
beta = 1
|
373 |
+
|
374 |
+
attention_scores = torch.baddbmm(
|
375 |
+
baddbmm_input,
|
376 |
+
query,
|
377 |
+
key.transpose(-1, -2),
|
378 |
+
beta=beta,
|
379 |
+
alpha=self.scale,
|
380 |
+
)
|
381 |
+
del baddbmm_input
|
382 |
+
|
383 |
+
if self.upcast_softmax:
|
384 |
+
attention_scores = attention_scores.float()
|
385 |
+
|
386 |
+
# Rich-Text: font size
|
387 |
+
if attn_weights is not None:
|
388 |
+
assert key.shape[1] == 77
|
389 |
+
attention_scores_stable = attention_scores - attention_scores.max(-1, True)[0]
|
390 |
+
attention_score_exp = attention_scores_stable.float().exp()
|
391 |
+
# attention_score_exp = attention_scores.float().exp()
|
392 |
+
font_size_abs, font_size_sign = attn_weights['font_size'].abs(), attn_weights['font_size'].sign()
|
393 |
+
attention_score_exp[:, :, attn_weights['word_pos']] = attention_score_exp[:, :, attn_weights['word_pos']].clone(
|
394 |
+
)*font_size_abs
|
395 |
+
attention_probs = attention_score_exp / attention_score_exp.sum(-1, True)
|
396 |
+
attention_probs[:, :, attn_weights['word_pos']] *= font_size_sign
|
397 |
+
# import ipdb; ipdb.set_trace()
|
398 |
+
if attention_probs.isnan().any():
|
399 |
+
import ipdb; ipdb.set_trace()
|
400 |
+
else:
|
401 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
402 |
+
|
403 |
+
del attention_scores
|
404 |
+
|
405 |
+
attention_probs = attention_probs.to(dtype)
|
406 |
+
|
407 |
+
return attention_probs
|
408 |
+
|
409 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
410 |
+
if batch_size is None:
|
411 |
+
deprecate(
|
412 |
+
"batch_size=None",
|
413 |
+
"0.0.15",
|
414 |
+
(
|
415 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
416 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
417 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
418 |
+
),
|
419 |
+
)
|
420 |
+
batch_size = 1
|
421 |
+
|
422 |
+
head_size = self.heads
|
423 |
+
if attention_mask is None:
|
424 |
+
return attention_mask
|
425 |
+
|
426 |
+
current_length: int = attention_mask.shape[-1]
|
427 |
+
if current_length != target_length:
|
428 |
+
if attention_mask.device.type == "mps":
|
429 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
430 |
+
# Instead, we can manually construct the padding tensor.
|
431 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
432 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
433 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
434 |
+
else:
|
435 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
436 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
437 |
+
# remaining_length: int = target_length - current_length
|
438 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
439 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
440 |
+
|
441 |
+
if out_dim == 3:
|
442 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
443 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
444 |
+
elif out_dim == 4:
|
445 |
+
attention_mask = attention_mask.unsqueeze(1)
|
446 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
447 |
+
|
448 |
+
return attention_mask
|
449 |
+
|
450 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
451 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
452 |
+
|
453 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
454 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
455 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
456 |
+
# Group norm norms along the channels dimension and expects
|
457 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
458 |
+
# to norm along the hidden dimension, so we need to move
|
459 |
+
# (batch_size, sequence_length, hidden_size) ->
|
460 |
+
# (batch_size, hidden_size, sequence_length)
|
461 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
462 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
463 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
464 |
+
else:
|
465 |
+
assert False
|
466 |
+
|
467 |
+
return encoder_hidden_states
|
468 |
+
|
469 |
+
|
470 |
+
class AttnProcessor:
|
471 |
+
r"""
|
472 |
+
Default processor for performing attention-related computations.
|
473 |
+
"""
|
474 |
+
|
475 |
+
# Rich-Text: inject self-attention maps
|
476 |
+
def __call__(
|
477 |
+
self,
|
478 |
+
attn: Attention,
|
479 |
+
hidden_states,
|
480 |
+
real_attn_probs=None,
|
481 |
+
attn_weights=None,
|
482 |
+
encoder_hidden_states=None,
|
483 |
+
attention_mask=None,
|
484 |
+
temb=None,
|
485 |
+
):
|
486 |
+
residual = hidden_states
|
487 |
+
|
488 |
+
if attn.spatial_norm is not None:
|
489 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
490 |
+
|
491 |
+
input_ndim = hidden_states.ndim
|
492 |
+
|
493 |
+
if input_ndim == 4:
|
494 |
+
batch_size, channel, height, width = hidden_states.shape
|
495 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
496 |
+
|
497 |
+
batch_size, sequence_length, _ = (
|
498 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
499 |
+
)
|
500 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
501 |
+
|
502 |
+
if attn.group_norm is not None:
|
503 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
504 |
+
|
505 |
+
query = attn.to_q(hidden_states)
|
506 |
+
|
507 |
+
if encoder_hidden_states is None:
|
508 |
+
encoder_hidden_states = hidden_states
|
509 |
+
elif attn.norm_cross:
|
510 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
511 |
+
|
512 |
+
key = attn.to_k(encoder_hidden_states)
|
513 |
+
value = attn.to_v(encoder_hidden_states)
|
514 |
+
|
515 |
+
query = attn.head_to_batch_dim(query)
|
516 |
+
key = attn.head_to_batch_dim(key)
|
517 |
+
value = attn.head_to_batch_dim(value)
|
518 |
+
|
519 |
+
if real_attn_probs is None:
|
520 |
+
# Rich-Text: font size
|
521 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask, attn_weights=attn_weights)
|
522 |
+
else:
|
523 |
+
# Rich-Text: inject self-attention maps
|
524 |
+
attention_probs = real_attn_probs
|
525 |
+
hidden_states = torch.bmm(attention_probs, value)
|
526 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
527 |
+
|
528 |
+
# linear proj
|
529 |
+
hidden_states = attn.to_out[0](hidden_states)
|
530 |
+
# dropout
|
531 |
+
hidden_states = attn.to_out[1](hidden_states)
|
532 |
+
|
533 |
+
if input_ndim == 4:
|
534 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
535 |
+
|
536 |
+
if attn.residual_connection:
|
537 |
+
hidden_states = hidden_states + residual
|
538 |
+
|
539 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
540 |
+
|
541 |
+
# Rich-Text Modified: return attn probs
|
542 |
+
# We return the map averaged over heads to save memory footprint
|
543 |
+
attention_probs_avg = attn.reshape_batch_dim_to_heads_and_average(
|
544 |
+
attention_probs)
|
545 |
+
return hidden_states, [attention_probs_avg, attention_probs]
|
546 |
+
|
547 |
+
|
548 |
+
class LoRALinearLayer(nn.Module):
|
549 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
550 |
+
super().__init__()
|
551 |
+
|
552 |
+
if rank > min(in_features, out_features):
|
553 |
+
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
554 |
+
|
555 |
+
self.down = nn.Linear(in_features, rank, bias=False)
|
556 |
+
self.up = nn.Linear(rank, out_features, bias=False)
|
557 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
558 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
559 |
+
self.network_alpha = network_alpha
|
560 |
+
self.rank = rank
|
561 |
+
|
562 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
563 |
+
nn.init.zeros_(self.up.weight)
|
564 |
+
|
565 |
+
def forward(self, hidden_states):
|
566 |
+
orig_dtype = hidden_states.dtype
|
567 |
+
dtype = self.down.weight.dtype
|
568 |
+
|
569 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
570 |
+
up_hidden_states = self.up(down_hidden_states)
|
571 |
+
|
572 |
+
if self.network_alpha is not None:
|
573 |
+
up_hidden_states *= self.network_alpha / self.rank
|
574 |
+
|
575 |
+
return up_hidden_states.to(orig_dtype)
|
576 |
+
|
577 |
+
|
578 |
+
class LoRAAttnProcessor(nn.Module):
|
579 |
+
r"""
|
580 |
+
Processor for implementing the LoRA attention mechanism.
|
581 |
+
|
582 |
+
Args:
|
583 |
+
hidden_size (`int`, *optional*):
|
584 |
+
The hidden size of the attention layer.
|
585 |
+
cross_attention_dim (`int`, *optional*):
|
586 |
+
The number of channels in the `encoder_hidden_states`.
|
587 |
+
rank (`int`, defaults to 4):
|
588 |
+
The dimension of the LoRA update matrices.
|
589 |
+
network_alpha (`int`, *optional*):
|
590 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
591 |
+
"""
|
592 |
+
|
593 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
594 |
+
super().__init__()
|
595 |
+
|
596 |
+
self.hidden_size = hidden_size
|
597 |
+
self.cross_attention_dim = cross_attention_dim
|
598 |
+
self.rank = rank
|
599 |
+
|
600 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
601 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
602 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
603 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
604 |
+
|
605 |
+
def __call__(
|
606 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
607 |
+
):
|
608 |
+
residual = hidden_states
|
609 |
+
|
610 |
+
if attn.spatial_norm is not None:
|
611 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
612 |
+
|
613 |
+
input_ndim = hidden_states.ndim
|
614 |
+
|
615 |
+
if input_ndim == 4:
|
616 |
+
batch_size, channel, height, width = hidden_states.shape
|
617 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
618 |
+
|
619 |
+
batch_size, sequence_length, _ = (
|
620 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
621 |
+
)
|
622 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
623 |
+
|
624 |
+
if attn.group_norm is not None:
|
625 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
626 |
+
|
627 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
628 |
+
query = attn.head_to_batch_dim(query)
|
629 |
+
|
630 |
+
if encoder_hidden_states is None:
|
631 |
+
encoder_hidden_states = hidden_states
|
632 |
+
elif attn.norm_cross:
|
633 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
634 |
+
|
635 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
636 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
637 |
+
|
638 |
+
key = attn.head_to_batch_dim(key)
|
639 |
+
value = attn.head_to_batch_dim(value)
|
640 |
+
|
641 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
642 |
+
hidden_states = torch.bmm(attention_probs, value)
|
643 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
644 |
+
|
645 |
+
# linear proj
|
646 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
647 |
+
# dropout
|
648 |
+
hidden_states = attn.to_out[1](hidden_states)
|
649 |
+
|
650 |
+
if input_ndim == 4:
|
651 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
652 |
+
|
653 |
+
if attn.residual_connection:
|
654 |
+
hidden_states = hidden_states + residual
|
655 |
+
|
656 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
657 |
+
|
658 |
+
return hidden_states
|
659 |
+
|
660 |
+
|
661 |
+
class CustomDiffusionAttnProcessor(nn.Module):
|
662 |
+
r"""
|
663 |
+
Processor for implementing attention for the Custom Diffusion method.
|
664 |
+
|
665 |
+
Args:
|
666 |
+
train_kv (`bool`, defaults to `True`):
|
667 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
668 |
+
train_q_out (`bool`, defaults to `True`):
|
669 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
670 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
671 |
+
The hidden size of the attention layer.
|
672 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
673 |
+
The number of channels in the `encoder_hidden_states`.
|
674 |
+
out_bias (`bool`, defaults to `True`):
|
675 |
+
Whether to include the bias parameter in `train_q_out`.
|
676 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
677 |
+
The dropout probability to use.
|
678 |
+
"""
|
679 |
+
|
680 |
+
def __init__(
|
681 |
+
self,
|
682 |
+
train_kv=True,
|
683 |
+
train_q_out=True,
|
684 |
+
hidden_size=None,
|
685 |
+
cross_attention_dim=None,
|
686 |
+
out_bias=True,
|
687 |
+
dropout=0.0,
|
688 |
+
):
|
689 |
+
super().__init__()
|
690 |
+
self.train_kv = train_kv
|
691 |
+
self.train_q_out = train_q_out
|
692 |
+
|
693 |
+
self.hidden_size = hidden_size
|
694 |
+
self.cross_attention_dim = cross_attention_dim
|
695 |
+
|
696 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
697 |
+
if self.train_kv:
|
698 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
699 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
700 |
+
if self.train_q_out:
|
701 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
702 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
703 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
704 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
705 |
+
|
706 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
707 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
708 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
709 |
+
if self.train_q_out:
|
710 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
711 |
+
else:
|
712 |
+
query = attn.to_q(hidden_states)
|
713 |
+
|
714 |
+
if encoder_hidden_states is None:
|
715 |
+
crossattn = False
|
716 |
+
encoder_hidden_states = hidden_states
|
717 |
+
else:
|
718 |
+
crossattn = True
|
719 |
+
if attn.norm_cross:
|
720 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
721 |
+
|
722 |
+
if self.train_kv:
|
723 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
724 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
725 |
+
else:
|
726 |
+
key = attn.to_k(encoder_hidden_states)
|
727 |
+
value = attn.to_v(encoder_hidden_states)
|
728 |
+
|
729 |
+
if crossattn:
|
730 |
+
detach = torch.ones_like(key)
|
731 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
732 |
+
key = detach * key + (1 - detach) * key.detach()
|
733 |
+
value = detach * value + (1 - detach) * value.detach()
|
734 |
+
|
735 |
+
query = attn.head_to_batch_dim(query)
|
736 |
+
key = attn.head_to_batch_dim(key)
|
737 |
+
value = attn.head_to_batch_dim(value)
|
738 |
+
|
739 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
740 |
+
hidden_states = torch.bmm(attention_probs, value)
|
741 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
742 |
+
|
743 |
+
if self.train_q_out:
|
744 |
+
# linear proj
|
745 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
746 |
+
# dropout
|
747 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
748 |
+
else:
|
749 |
+
# linear proj
|
750 |
+
hidden_states = attn.to_out[0](hidden_states)
|
751 |
+
# dropout
|
752 |
+
hidden_states = attn.to_out[1](hidden_states)
|
753 |
+
|
754 |
+
return hidden_states
|
755 |
+
|
756 |
+
|
757 |
+
class AttnAddedKVProcessor:
|
758 |
+
r"""
|
759 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
760 |
+
encoder.
|
761 |
+
"""
|
762 |
+
|
763 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
764 |
+
residual = hidden_states
|
765 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
766 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
767 |
+
|
768 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
769 |
+
|
770 |
+
if encoder_hidden_states is None:
|
771 |
+
encoder_hidden_states = hidden_states
|
772 |
+
elif attn.norm_cross:
|
773 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
774 |
+
|
775 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
776 |
+
|
777 |
+
query = attn.to_q(hidden_states)
|
778 |
+
query = attn.head_to_batch_dim(query)
|
779 |
+
|
780 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
781 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
782 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
783 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
784 |
+
|
785 |
+
if not attn.only_cross_attention:
|
786 |
+
key = attn.to_k(hidden_states)
|
787 |
+
value = attn.to_v(hidden_states)
|
788 |
+
key = attn.head_to_batch_dim(key)
|
789 |
+
value = attn.head_to_batch_dim(value)
|
790 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
791 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
792 |
+
else:
|
793 |
+
key = encoder_hidden_states_key_proj
|
794 |
+
value = encoder_hidden_states_value_proj
|
795 |
+
|
796 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
797 |
+
hidden_states = torch.bmm(attention_probs, value)
|
798 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
799 |
+
|
800 |
+
# linear proj
|
801 |
+
hidden_states = attn.to_out[0](hidden_states)
|
802 |
+
# dropout
|
803 |
+
hidden_states = attn.to_out[1](hidden_states)
|
804 |
+
|
805 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
806 |
+
hidden_states = hidden_states + residual
|
807 |
+
|
808 |
+
return hidden_states
|
809 |
+
|
810 |
+
|
811 |
+
class AttnAddedKVProcessor2_0:
|
812 |
+
r"""
|
813 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
814 |
+
learnable key and value matrices for the text encoder.
|
815 |
+
"""
|
816 |
+
|
817 |
+
def __init__(self):
|
818 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
819 |
+
raise ImportError(
|
820 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
821 |
+
)
|
822 |
+
|
823 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
824 |
+
residual = hidden_states
|
825 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
826 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
827 |
+
|
828 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
829 |
+
|
830 |
+
if encoder_hidden_states is None:
|
831 |
+
encoder_hidden_states = hidden_states
|
832 |
+
elif attn.norm_cross:
|
833 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
834 |
+
|
835 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
836 |
+
|
837 |
+
query = attn.to_q(hidden_states)
|
838 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
839 |
+
|
840 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
841 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
842 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
843 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
844 |
+
|
845 |
+
if not attn.only_cross_attention:
|
846 |
+
key = attn.to_k(hidden_states)
|
847 |
+
value = attn.to_v(hidden_states)
|
848 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
849 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
850 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
851 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
852 |
+
else:
|
853 |
+
key = encoder_hidden_states_key_proj
|
854 |
+
value = encoder_hidden_states_value_proj
|
855 |
+
|
856 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
857 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
858 |
+
hidden_states = F.scaled_dot_product_attention(
|
859 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
860 |
+
)
|
861 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
862 |
+
|
863 |
+
# linear proj
|
864 |
+
hidden_states = attn.to_out[0](hidden_states)
|
865 |
+
# dropout
|
866 |
+
hidden_states = attn.to_out[1](hidden_states)
|
867 |
+
|
868 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
869 |
+
hidden_states = hidden_states + residual
|
870 |
+
|
871 |
+
return hidden_states
|
872 |
+
|
873 |
+
|
874 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
875 |
+
r"""
|
876 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
877 |
+
encoder.
|
878 |
+
|
879 |
+
Args:
|
880 |
+
hidden_size (`int`, *optional*):
|
881 |
+
The hidden size of the attention layer.
|
882 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
883 |
+
The number of channels in the `encoder_hidden_states`.
|
884 |
+
rank (`int`, defaults to 4):
|
885 |
+
The dimension of the LoRA update matrices.
|
886 |
+
|
887 |
+
"""
|
888 |
+
|
889 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
890 |
+
super().__init__()
|
891 |
+
|
892 |
+
self.hidden_size = hidden_size
|
893 |
+
self.cross_attention_dim = cross_attention_dim
|
894 |
+
self.rank = rank
|
895 |
+
|
896 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
897 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
898 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
899 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
900 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
901 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
902 |
+
|
903 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
904 |
+
residual = hidden_states
|
905 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
906 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
907 |
+
|
908 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
909 |
+
|
910 |
+
if encoder_hidden_states is None:
|
911 |
+
encoder_hidden_states = hidden_states
|
912 |
+
elif attn.norm_cross:
|
913 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
914 |
+
|
915 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
916 |
+
|
917 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
918 |
+
query = attn.head_to_batch_dim(query)
|
919 |
+
|
920 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
921 |
+
encoder_hidden_states
|
922 |
+
)
|
923 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
924 |
+
encoder_hidden_states
|
925 |
+
)
|
926 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
927 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
928 |
+
|
929 |
+
if not attn.only_cross_attention:
|
930 |
+
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
931 |
+
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
932 |
+
key = attn.head_to_batch_dim(key)
|
933 |
+
value = attn.head_to_batch_dim(value)
|
934 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
935 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
936 |
+
else:
|
937 |
+
key = encoder_hidden_states_key_proj
|
938 |
+
value = encoder_hidden_states_value_proj
|
939 |
+
|
940 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
941 |
+
hidden_states = torch.bmm(attention_probs, value)
|
942 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
943 |
+
|
944 |
+
# linear proj
|
945 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
946 |
+
# dropout
|
947 |
+
hidden_states = attn.to_out[1](hidden_states)
|
948 |
+
|
949 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
950 |
+
hidden_states = hidden_states + residual
|
951 |
+
|
952 |
+
return hidden_states
|
953 |
+
|
954 |
+
|
955 |
+
class XFormersAttnAddedKVProcessor:
|
956 |
+
r"""
|
957 |
+
Processor for implementing memory efficient attention using xFormers.
|
958 |
+
|
959 |
+
Args:
|
960 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
961 |
+
The base
|
962 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
963 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
964 |
+
operator.
|
965 |
+
"""
|
966 |
+
|
967 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
968 |
+
self.attention_op = attention_op
|
969 |
+
|
970 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
971 |
+
residual = hidden_states
|
972 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
973 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
974 |
+
|
975 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
976 |
+
|
977 |
+
if encoder_hidden_states is None:
|
978 |
+
encoder_hidden_states = hidden_states
|
979 |
+
elif attn.norm_cross:
|
980 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
981 |
+
|
982 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
983 |
+
|
984 |
+
query = attn.to_q(hidden_states)
|
985 |
+
query = attn.head_to_batch_dim(query)
|
986 |
+
|
987 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
988 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
989 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
990 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
991 |
+
|
992 |
+
if not attn.only_cross_attention:
|
993 |
+
key = attn.to_k(hidden_states)
|
994 |
+
value = attn.to_v(hidden_states)
|
995 |
+
key = attn.head_to_batch_dim(key)
|
996 |
+
value = attn.head_to_batch_dim(value)
|
997 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
998 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
999 |
+
else:
|
1000 |
+
key = encoder_hidden_states_key_proj
|
1001 |
+
value = encoder_hidden_states_value_proj
|
1002 |
+
|
1003 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1004 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1005 |
+
)
|
1006 |
+
hidden_states = hidden_states.to(query.dtype)
|
1007 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1008 |
+
|
1009 |
+
# linear proj
|
1010 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1011 |
+
# dropout
|
1012 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1013 |
+
|
1014 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1015 |
+
hidden_states = hidden_states + residual
|
1016 |
+
|
1017 |
+
return hidden_states
|
1018 |
+
|
1019 |
+
|
1020 |
+
class XFormersAttnProcessor:
|
1021 |
+
r"""
|
1022 |
+
Processor for implementing memory efficient attention using xFormers.
|
1023 |
+
|
1024 |
+
Args:
|
1025 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1026 |
+
The base
|
1027 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1028 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1029 |
+
operator.
|
1030 |
+
"""
|
1031 |
+
|
1032 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1033 |
+
self.attention_op = attention_op
|
1034 |
+
|
1035 |
+
def __call__(
|
1036 |
+
self,
|
1037 |
+
attn: Attention,
|
1038 |
+
hidden_states: torch.FloatTensor,
|
1039 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1040 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1041 |
+
temb: Optional[torch.FloatTensor] = None,
|
1042 |
+
):
|
1043 |
+
residual = hidden_states
|
1044 |
+
|
1045 |
+
if attn.spatial_norm is not None:
|
1046 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1047 |
+
|
1048 |
+
input_ndim = hidden_states.ndim
|
1049 |
+
|
1050 |
+
if input_ndim == 4:
|
1051 |
+
batch_size, channel, height, width = hidden_states.shape
|
1052 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1053 |
+
|
1054 |
+
batch_size, key_tokens, _ = (
|
1055 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1056 |
+
)
|
1057 |
+
|
1058 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1059 |
+
if attention_mask is not None:
|
1060 |
+
# expand our mask's singleton query_tokens dimension:
|
1061 |
+
# [batch*heads, 1, key_tokens] ->
|
1062 |
+
# [batch*heads, query_tokens, key_tokens]
|
1063 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1064 |
+
# [batch*heads, query_tokens, key_tokens]
|
1065 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1066 |
+
_, query_tokens, _ = hidden_states.shape
|
1067 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1068 |
+
|
1069 |
+
if attn.group_norm is not None:
|
1070 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1071 |
+
|
1072 |
+
query = attn.to_q(hidden_states)
|
1073 |
+
|
1074 |
+
if encoder_hidden_states is None:
|
1075 |
+
encoder_hidden_states = hidden_states
|
1076 |
+
elif attn.norm_cross:
|
1077 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1078 |
+
|
1079 |
+
key = attn.to_k(encoder_hidden_states)
|
1080 |
+
value = attn.to_v(encoder_hidden_states)
|
1081 |
+
|
1082 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1083 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1084 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1085 |
+
|
1086 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1087 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1088 |
+
)
|
1089 |
+
hidden_states = hidden_states.to(query.dtype)
|
1090 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1091 |
+
|
1092 |
+
# linear proj
|
1093 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1094 |
+
# dropout
|
1095 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1096 |
+
|
1097 |
+
if input_ndim == 4:
|
1098 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1099 |
+
|
1100 |
+
if attn.residual_connection:
|
1101 |
+
hidden_states = hidden_states + residual
|
1102 |
+
|
1103 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1104 |
+
|
1105 |
+
return hidden_states
|
1106 |
+
|
1107 |
+
|
1108 |
+
class AttnProcessor2_0:
|
1109 |
+
r"""
|
1110 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1111 |
+
"""
|
1112 |
+
|
1113 |
+
def __init__(self):
|
1114 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1115 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1116 |
+
|
1117 |
+
def __call__(
|
1118 |
+
self,
|
1119 |
+
attn: Attention,
|
1120 |
+
hidden_states,
|
1121 |
+
encoder_hidden_states=None,
|
1122 |
+
attention_mask=None,
|
1123 |
+
temb=None,
|
1124 |
+
):
|
1125 |
+
residual = hidden_states
|
1126 |
+
|
1127 |
+
if attn.spatial_norm is not None:
|
1128 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1129 |
+
|
1130 |
+
input_ndim = hidden_states.ndim
|
1131 |
+
|
1132 |
+
if input_ndim == 4:
|
1133 |
+
batch_size, channel, height, width = hidden_states.shape
|
1134 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1135 |
+
|
1136 |
+
batch_size, sequence_length, _ = (
|
1137 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1138 |
+
)
|
1139 |
+
inner_dim = hidden_states.shape[-1]
|
1140 |
+
|
1141 |
+
if attention_mask is not None:
|
1142 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1143 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1144 |
+
# (batch, heads, source_length, target_length)
|
1145 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1146 |
+
|
1147 |
+
if attn.group_norm is not None:
|
1148 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1149 |
+
|
1150 |
+
query = attn.to_q(hidden_states)
|
1151 |
+
|
1152 |
+
if encoder_hidden_states is None:
|
1153 |
+
encoder_hidden_states = hidden_states
|
1154 |
+
elif attn.norm_cross:
|
1155 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1156 |
+
|
1157 |
+
key = attn.to_k(encoder_hidden_states)
|
1158 |
+
value = attn.to_v(encoder_hidden_states)
|
1159 |
+
|
1160 |
+
head_dim = inner_dim // attn.heads
|
1161 |
+
|
1162 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1163 |
+
|
1164 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1165 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1166 |
+
|
1167 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1168 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1169 |
+
hidden_states = F.scaled_dot_product_attention(
|
1170 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1174 |
+
hidden_states = hidden_states.to(query.dtype)
|
1175 |
+
|
1176 |
+
# linear proj
|
1177 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1178 |
+
# dropout
|
1179 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1180 |
+
|
1181 |
+
if input_ndim == 4:
|
1182 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1183 |
+
|
1184 |
+
if attn.residual_connection:
|
1185 |
+
hidden_states = hidden_states + residual
|
1186 |
+
|
1187 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1188 |
+
|
1189 |
+
return hidden_states
|
1190 |
+
|
1191 |
+
|
1192 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1193 |
+
r"""
|
1194 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1195 |
+
|
1196 |
+
Args:
|
1197 |
+
hidden_size (`int`, *optional*):
|
1198 |
+
The hidden size of the attention layer.
|
1199 |
+
cross_attention_dim (`int`, *optional*):
|
1200 |
+
The number of channels in the `encoder_hidden_states`.
|
1201 |
+
rank (`int`, defaults to 4):
|
1202 |
+
The dimension of the LoRA update matrices.
|
1203 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1204 |
+
The base
|
1205 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1206 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1207 |
+
operator.
|
1208 |
+
network_alpha (`int`, *optional*):
|
1209 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1210 |
+
|
1211 |
+
"""
|
1212 |
+
|
1213 |
+
def __init__(
|
1214 |
+
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
|
1215 |
+
):
|
1216 |
+
super().__init__()
|
1217 |
+
|
1218 |
+
self.hidden_size = hidden_size
|
1219 |
+
self.cross_attention_dim = cross_attention_dim
|
1220 |
+
self.rank = rank
|
1221 |
+
self.attention_op = attention_op
|
1222 |
+
|
1223 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1224 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1225 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1226 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1227 |
+
|
1228 |
+
def __call__(
|
1229 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
1230 |
+
):
|
1231 |
+
residual = hidden_states
|
1232 |
+
|
1233 |
+
if attn.spatial_norm is not None:
|
1234 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1235 |
+
|
1236 |
+
input_ndim = hidden_states.ndim
|
1237 |
+
|
1238 |
+
if input_ndim == 4:
|
1239 |
+
batch_size, channel, height, width = hidden_states.shape
|
1240 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1241 |
+
|
1242 |
+
batch_size, sequence_length, _ = (
|
1243 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1244 |
+
)
|
1245 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1246 |
+
|
1247 |
+
if attn.group_norm is not None:
|
1248 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1249 |
+
|
1250 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1251 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1252 |
+
|
1253 |
+
if encoder_hidden_states is None:
|
1254 |
+
encoder_hidden_states = hidden_states
|
1255 |
+
elif attn.norm_cross:
|
1256 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1257 |
+
|
1258 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1259 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1260 |
+
|
1261 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1262 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1263 |
+
|
1264 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1265 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1266 |
+
)
|
1267 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1268 |
+
|
1269 |
+
# linear proj
|
1270 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1271 |
+
# dropout
|
1272 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1273 |
+
|
1274 |
+
if input_ndim == 4:
|
1275 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1276 |
+
|
1277 |
+
if attn.residual_connection:
|
1278 |
+
hidden_states = hidden_states + residual
|
1279 |
+
|
1280 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1281 |
+
|
1282 |
+
return hidden_states
|
1283 |
+
|
1284 |
+
|
1285 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1286 |
+
r"""
|
1287 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1288 |
+
attention.
|
1289 |
+
|
1290 |
+
Args:
|
1291 |
+
hidden_size (`int`):
|
1292 |
+
The hidden size of the attention layer.
|
1293 |
+
cross_attention_dim (`int`, *optional*):
|
1294 |
+
The number of channels in the `encoder_hidden_states`.
|
1295 |
+
rank (`int`, defaults to 4):
|
1296 |
+
The dimension of the LoRA update matrices.
|
1297 |
+
network_alpha (`int`, *optional*):
|
1298 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1299 |
+
"""
|
1300 |
+
|
1301 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
1302 |
+
super().__init__()
|
1303 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1304 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1305 |
+
|
1306 |
+
self.hidden_size = hidden_size
|
1307 |
+
self.cross_attention_dim = cross_attention_dim
|
1308 |
+
self.rank = rank
|
1309 |
+
|
1310 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1311 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1312 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1313 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1314 |
+
|
1315 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1316 |
+
residual = hidden_states
|
1317 |
+
|
1318 |
+
input_ndim = hidden_states.ndim
|
1319 |
+
|
1320 |
+
if input_ndim == 4:
|
1321 |
+
batch_size, channel, height, width = hidden_states.shape
|
1322 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1323 |
+
|
1324 |
+
batch_size, sequence_length, _ = (
|
1325 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1326 |
+
)
|
1327 |
+
inner_dim = hidden_states.shape[-1]
|
1328 |
+
|
1329 |
+
if attention_mask is not None:
|
1330 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1331 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1332 |
+
# (batch, heads, source_length, target_length)
|
1333 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1334 |
+
|
1335 |
+
if attn.group_norm is not None:
|
1336 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1337 |
+
|
1338 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1339 |
+
|
1340 |
+
if encoder_hidden_states is None:
|
1341 |
+
encoder_hidden_states = hidden_states
|
1342 |
+
elif attn.norm_cross:
|
1343 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1344 |
+
|
1345 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1346 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1347 |
+
|
1348 |
+
head_dim = inner_dim // attn.heads
|
1349 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1350 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1351 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1352 |
+
|
1353 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1354 |
+
hidden_states = F.scaled_dot_product_attention(
|
1355 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1356 |
+
)
|
1357 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1358 |
+
hidden_states = hidden_states.to(query.dtype)
|
1359 |
+
|
1360 |
+
# linear proj
|
1361 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1362 |
+
# dropout
|
1363 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1364 |
+
|
1365 |
+
if input_ndim == 4:
|
1366 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1367 |
+
|
1368 |
+
if attn.residual_connection:
|
1369 |
+
hidden_states = hidden_states + residual
|
1370 |
+
|
1371 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1372 |
+
|
1373 |
+
return hidden_states
|
1374 |
+
|
1375 |
+
|
1376 |
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1377 |
+
r"""
|
1378 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1379 |
+
|
1380 |
+
Args:
|
1381 |
+
train_kv (`bool`, defaults to `True`):
|
1382 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1383 |
+
train_q_out (`bool`, defaults to `True`):
|
1384 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1385 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1386 |
+
The hidden size of the attention layer.
|
1387 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1388 |
+
The number of channels in the `encoder_hidden_states`.
|
1389 |
+
out_bias (`bool`, defaults to `True`):
|
1390 |
+
Whether to include the bias parameter in `train_q_out`.
|
1391 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1392 |
+
The dropout probability to use.
|
1393 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1394 |
+
The base
|
1395 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1396 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1397 |
+
"""
|
1398 |
+
|
1399 |
+
def __init__(
|
1400 |
+
self,
|
1401 |
+
train_kv=True,
|
1402 |
+
train_q_out=False,
|
1403 |
+
hidden_size=None,
|
1404 |
+
cross_attention_dim=None,
|
1405 |
+
out_bias=True,
|
1406 |
+
dropout=0.0,
|
1407 |
+
attention_op: Optional[Callable] = None,
|
1408 |
+
):
|
1409 |
+
super().__init__()
|
1410 |
+
self.train_kv = train_kv
|
1411 |
+
self.train_q_out = train_q_out
|
1412 |
+
|
1413 |
+
self.hidden_size = hidden_size
|
1414 |
+
self.cross_attention_dim = cross_attention_dim
|
1415 |
+
self.attention_op = attention_op
|
1416 |
+
|
1417 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1418 |
+
if self.train_kv:
|
1419 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1420 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1421 |
+
if self.train_q_out:
|
1422 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1423 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1424 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1425 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1426 |
+
|
1427 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1428 |
+
batch_size, sequence_length, _ = (
|
1429 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1430 |
+
)
|
1431 |
+
|
1432 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1433 |
+
|
1434 |
+
if self.train_q_out:
|
1435 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
1436 |
+
else:
|
1437 |
+
query = attn.to_q(hidden_states)
|
1438 |
+
|
1439 |
+
if encoder_hidden_states is None:
|
1440 |
+
crossattn = False
|
1441 |
+
encoder_hidden_states = hidden_states
|
1442 |
+
else:
|
1443 |
+
crossattn = True
|
1444 |
+
if attn.norm_cross:
|
1445 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1446 |
+
|
1447 |
+
if self.train_kv:
|
1448 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
1449 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
1450 |
+
else:
|
1451 |
+
key = attn.to_k(encoder_hidden_states)
|
1452 |
+
value = attn.to_v(encoder_hidden_states)
|
1453 |
+
|
1454 |
+
if crossattn:
|
1455 |
+
detach = torch.ones_like(key)
|
1456 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1457 |
+
key = detach * key + (1 - detach) * key.detach()
|
1458 |
+
value = detach * value + (1 - detach) * value.detach()
|
1459 |
+
|
1460 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1461 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1462 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1463 |
+
|
1464 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1465 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1466 |
+
)
|
1467 |
+
hidden_states = hidden_states.to(query.dtype)
|
1468 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1469 |
+
|
1470 |
+
if self.train_q_out:
|
1471 |
+
# linear proj
|
1472 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1473 |
+
# dropout
|
1474 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1475 |
+
else:
|
1476 |
+
# linear proj
|
1477 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1478 |
+
# dropout
|
1479 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1480 |
+
return hidden_states
|
1481 |
+
|
1482 |
+
|
1483 |
+
class SlicedAttnProcessor:
|
1484 |
+
r"""
|
1485 |
+
Processor for implementing sliced attention.
|
1486 |
+
|
1487 |
+
Args:
|
1488 |
+
slice_size (`int`, *optional*):
|
1489 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1490 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1491 |
+
"""
|
1492 |
+
|
1493 |
+
def __init__(self, slice_size):
|
1494 |
+
self.slice_size = slice_size
|
1495 |
+
|
1496 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1497 |
+
residual = hidden_states
|
1498 |
+
|
1499 |
+
input_ndim = hidden_states.ndim
|
1500 |
+
|
1501 |
+
if input_ndim == 4:
|
1502 |
+
batch_size, channel, height, width = hidden_states.shape
|
1503 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1504 |
+
|
1505 |
+
batch_size, sequence_length, _ = (
|
1506 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1507 |
+
)
|
1508 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1509 |
+
|
1510 |
+
if attn.group_norm is not None:
|
1511 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1512 |
+
|
1513 |
+
query = attn.to_q(hidden_states)
|
1514 |
+
dim = query.shape[-1]
|
1515 |
+
query = attn.head_to_batch_dim(query)
|
1516 |
+
|
1517 |
+
if encoder_hidden_states is None:
|
1518 |
+
encoder_hidden_states = hidden_states
|
1519 |
+
elif attn.norm_cross:
|
1520 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1521 |
+
|
1522 |
+
key = attn.to_k(encoder_hidden_states)
|
1523 |
+
value = attn.to_v(encoder_hidden_states)
|
1524 |
+
key = attn.head_to_batch_dim(key)
|
1525 |
+
value = attn.head_to_batch_dim(value)
|
1526 |
+
|
1527 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1528 |
+
hidden_states = torch.zeros(
|
1529 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1530 |
+
)
|
1531 |
+
|
1532 |
+
for i in range(batch_size_attention // self.slice_size):
|
1533 |
+
start_idx = i * self.slice_size
|
1534 |
+
end_idx = (i + 1) * self.slice_size
|
1535 |
+
|
1536 |
+
query_slice = query[start_idx:end_idx]
|
1537 |
+
key_slice = key[start_idx:end_idx]
|
1538 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1539 |
+
|
1540 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1541 |
+
|
1542 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1543 |
+
|
1544 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1545 |
+
|
1546 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1547 |
+
|
1548 |
+
# linear proj
|
1549 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1550 |
+
# dropout
|
1551 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1552 |
+
|
1553 |
+
if input_ndim == 4:
|
1554 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1555 |
+
|
1556 |
+
if attn.residual_connection:
|
1557 |
+
hidden_states = hidden_states + residual
|
1558 |
+
|
1559 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1560 |
+
|
1561 |
+
return hidden_states
|
1562 |
+
|
1563 |
+
|
1564 |
+
class SlicedAttnAddedKVProcessor:
|
1565 |
+
r"""
|
1566 |
+
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
|
1567 |
+
|
1568 |
+
Args:
|
1569 |
+
slice_size (`int`, *optional*):
|
1570 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1571 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1572 |
+
"""
|
1573 |
+
|
1574 |
+
def __init__(self, slice_size):
|
1575 |
+
self.slice_size = slice_size
|
1576 |
+
|
1577 |
+
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
1578 |
+
residual = hidden_states
|
1579 |
+
|
1580 |
+
if attn.spatial_norm is not None:
|
1581 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1582 |
+
|
1583 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1584 |
+
|
1585 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1586 |
+
|
1587 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1588 |
+
|
1589 |
+
if encoder_hidden_states is None:
|
1590 |
+
encoder_hidden_states = hidden_states
|
1591 |
+
elif attn.norm_cross:
|
1592 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1593 |
+
|
1594 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1595 |
+
|
1596 |
+
query = attn.to_q(hidden_states)
|
1597 |
+
dim = query.shape[-1]
|
1598 |
+
query = attn.head_to_batch_dim(query)
|
1599 |
+
|
1600 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1601 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1602 |
+
|
1603 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1604 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1605 |
+
|
1606 |
+
if not attn.only_cross_attention:
|
1607 |
+
key = attn.to_k(hidden_states)
|
1608 |
+
value = attn.to_v(hidden_states)
|
1609 |
+
key = attn.head_to_batch_dim(key)
|
1610 |
+
value = attn.head_to_batch_dim(value)
|
1611 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1612 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1613 |
+
else:
|
1614 |
+
key = encoder_hidden_states_key_proj
|
1615 |
+
value = encoder_hidden_states_value_proj
|
1616 |
+
|
1617 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1618 |
+
hidden_states = torch.zeros(
|
1619 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1620 |
+
)
|
1621 |
+
|
1622 |
+
for i in range(batch_size_attention // self.slice_size):
|
1623 |
+
start_idx = i * self.slice_size
|
1624 |
+
end_idx = (i + 1) * self.slice_size
|
1625 |
+
|
1626 |
+
query_slice = query[start_idx:end_idx]
|
1627 |
+
key_slice = key[start_idx:end_idx]
|
1628 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1629 |
+
|
1630 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1631 |
+
|
1632 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1633 |
+
|
1634 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1635 |
+
|
1636 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1637 |
+
|
1638 |
+
# linear proj
|
1639 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1640 |
+
# dropout
|
1641 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1642 |
+
|
1643 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1644 |
+
hidden_states = hidden_states + residual
|
1645 |
+
|
1646 |
+
return hidden_states
|
1647 |
+
|
1648 |
+
|
1649 |
+
AttentionProcessor = Union[
|
1650 |
+
AttnProcessor,
|
1651 |
+
AttnProcessor2_0,
|
1652 |
+
XFormersAttnProcessor,
|
1653 |
+
SlicedAttnProcessor,
|
1654 |
+
AttnAddedKVProcessor,
|
1655 |
+
SlicedAttnAddedKVProcessor,
|
1656 |
+
AttnAddedKVProcessor2_0,
|
1657 |
+
XFormersAttnAddedKVProcessor,
|
1658 |
+
LoRAAttnProcessor,
|
1659 |
+
LoRAXFormersAttnProcessor,
|
1660 |
+
LoRAAttnProcessor2_0,
|
1661 |
+
LoRAAttnAddedKVProcessor,
|
1662 |
+
CustomDiffusionAttnProcessor,
|
1663 |
+
CustomDiffusionXFormersAttnProcessor,
|
1664 |
+
]
|
1665 |
+
|
1666 |
+
|
1667 |
+
class SpatialNorm(nn.Module):
|
1668 |
+
"""
|
1669 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
|
1670 |
+
"""
|
1671 |
+
|
1672 |
+
def __init__(
|
1673 |
+
self,
|
1674 |
+
f_channels,
|
1675 |
+
zq_channels,
|
1676 |
+
):
|
1677 |
+
super().__init__()
|
1678 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
1679 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1680 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1681 |
+
|
1682 |
+
def forward(self, f, zq):
|
1683 |
+
f_size = f.shape[-2:]
|
1684 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
1685 |
+
norm_f = self.norm_layer(f)
|
1686 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
1687 |
+
return new_f
|
models/dual_transformer_2d.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from models.transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class DualTransformer2DModel(nn.Module):
|
22 |
+
"""
|
23 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
27 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
28 |
+
in_channels (`int`, *optional*):
|
29 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
30 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
31 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
32 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
33 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
34 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
35 |
+
`ImagePositionalEmbeddings`.
|
36 |
+
num_vector_embeds (`int`, *optional*):
|
37 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
38 |
+
Includes the class for the masked latent pixel.
|
39 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
40 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
41 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
42 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
43 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
44 |
+
attention_bias (`bool`, *optional*):
|
45 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
num_attention_heads: int = 16,
|
51 |
+
attention_head_dim: int = 88,
|
52 |
+
in_channels: Optional[int] = None,
|
53 |
+
num_layers: int = 1,
|
54 |
+
dropout: float = 0.0,
|
55 |
+
norm_num_groups: int = 32,
|
56 |
+
cross_attention_dim: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
sample_size: Optional[int] = None,
|
59 |
+
num_vector_embeds: Optional[int] = None,
|
60 |
+
activation_fn: str = "geglu",
|
61 |
+
num_embeds_ada_norm: Optional[int] = None,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.transformers = nn.ModuleList(
|
65 |
+
[
|
66 |
+
Transformer2DModel(
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=attention_head_dim,
|
69 |
+
in_channels=in_channels,
|
70 |
+
num_layers=num_layers,
|
71 |
+
dropout=dropout,
|
72 |
+
norm_num_groups=norm_num_groups,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attention_bias=attention_bias,
|
75 |
+
sample_size=sample_size,
|
76 |
+
num_vector_embeds=num_vector_embeds,
|
77 |
+
activation_fn=activation_fn,
|
78 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
79 |
+
)
|
80 |
+
for _ in range(2)
|
81 |
+
]
|
82 |
+
)
|
83 |
+
|
84 |
+
# Variables that can be set by a pipeline:
|
85 |
+
|
86 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
87 |
+
self.mix_ratio = 0.5
|
88 |
+
|
89 |
+
# The shape of `encoder_hidden_states` is expected to be
|
90 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
91 |
+
self.condition_lengths = [77, 257]
|
92 |
+
|
93 |
+
# Which transformer to use to encode which condition.
|
94 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
95 |
+
self.transformer_index_for_condition = [1, 0]
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states,
|
100 |
+
encoder_hidden_states,
|
101 |
+
timestep=None,
|
102 |
+
attention_mask=None,
|
103 |
+
cross_attention_kwargs=None,
|
104 |
+
return_dict: bool = True,
|
105 |
+
):
|
106 |
+
"""
|
107 |
+
Args:
|
108 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
109 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
110 |
+
hidden_states
|
111 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
112 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
113 |
+
self-attention.
|
114 |
+
timestep ( `torch.long`, *optional*):
|
115 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
116 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
117 |
+
Optional attention mask to be applied in Attention
|
118 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
119 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
123 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
124 |
+
returning a tuple, the first element is the sample tensor.
|
125 |
+
"""
|
126 |
+
input_states = hidden_states
|
127 |
+
|
128 |
+
encoded_states = []
|
129 |
+
tokens_start = 0
|
130 |
+
# attention_mask is not used yet
|
131 |
+
for i in range(2):
|
132 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
133 |
+
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
134 |
+
transformer_index = self.transformer_index_for_condition[i]
|
135 |
+
encoded_state = self.transformers[transformer_index](
|
136 |
+
input_states,
|
137 |
+
encoder_hidden_states=condition_state,
|
138 |
+
timestep=timestep,
|
139 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
140 |
+
return_dict=False,
|
141 |
+
)[0]
|
142 |
+
encoded_states.append(encoded_state - input_states)
|
143 |
+
tokens_start += self.condition_lengths[i]
|
144 |
+
|
145 |
+
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
146 |
+
output_states = output_states + input_states
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (output_states,)
|
150 |
+
|
151 |
+
return Transformer2DModelOutput(sample=output_states)
|
models/region_diffusion.py
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import collections
|
4 |
+
import torch.nn as nn
|
5 |
+
from functools import partial
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
7 |
+
from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
|
8 |
+
from models.unet_2d_condition import UNet2DConditionModel
|
9 |
+
from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers
|
10 |
+
|
11 |
+
# suppress partial model loading warning
|
12 |
+
logging.set_verbosity_error()
|
13 |
+
|
14 |
+
|
15 |
+
class RegionDiffusion(nn.Module):
|
16 |
+
def __init__(self, device):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.device = device
|
20 |
+
self.num_train_timesteps = 1000
|
21 |
+
self.clip_gradient = False
|
22 |
+
|
23 |
+
print(f'[INFO] loading stable diffusion...')
|
24 |
+
model_id = 'runwayml/stable-diffusion-v1-5'
|
25 |
+
|
26 |
+
self.vae = AutoencoderKL.from_pretrained(
|
27 |
+
model_id, subfolder="vae").to(self.device)
|
28 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
29 |
+
model_id, subfolder='tokenizer')
|
30 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
31 |
+
model_id, subfolder='text_encoder').to(self.device)
|
32 |
+
self.unet = UNet2DConditionModel.from_pretrained(
|
33 |
+
model_id, subfolder="unet").to(self.device)
|
34 |
+
|
35 |
+
self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
36 |
+
num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True, steps_offset=1)
|
37 |
+
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
38 |
+
|
39 |
+
self.masks = []
|
40 |
+
self.attention_maps = None
|
41 |
+
self.selfattn_maps = None
|
42 |
+
self.crossattn_maps = None
|
43 |
+
self.color_loss = torch.nn.functional.mse_loss
|
44 |
+
self.forward_hooks = []
|
45 |
+
self.forward_replacement_hooks = []
|
46 |
+
|
47 |
+
print(f'[INFO] loaded stable diffusion!')
|
48 |
+
|
49 |
+
def get_text_embeds(self, prompt, negative_prompt):
|
50 |
+
# prompt, negative_prompt: [str]
|
51 |
+
|
52 |
+
# Tokenize text and get embeddings
|
53 |
+
text_input = self.tokenizer(
|
54 |
+
prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
text_embeddings = self.text_encoder(
|
58 |
+
text_input.input_ids.to(self.device))[0]
|
59 |
+
|
60 |
+
# Do the same for unconditional embeddings
|
61 |
+
uncond_input = self.tokenizer(negative_prompt, padding='max_length',
|
62 |
+
max_length=self.tokenizer.model_max_length, return_tensors='pt')
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
uncond_embeddings = self.text_encoder(
|
66 |
+
uncond_input.input_ids.to(self.device))[0]
|
67 |
+
|
68 |
+
# Cat for final embeddings
|
69 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
70 |
+
return text_embeddings
|
71 |
+
|
72 |
+
def get_text_embeds_list(self, prompts):
|
73 |
+
# prompts: [list]
|
74 |
+
text_embeddings = []
|
75 |
+
for prompt in prompts:
|
76 |
+
# Tokenize text and get embeddings
|
77 |
+
text_input = self.tokenizer(
|
78 |
+
[prompt], padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
text_embeddings.append(self.text_encoder(
|
82 |
+
text_input.input_ids.to(self.device))[0])
|
83 |
+
|
84 |
+
return text_embeddings
|
85 |
+
|
86 |
+
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
|
87 |
+
latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, bg_aug_end=1000):
|
88 |
+
|
89 |
+
if latents is None:
|
90 |
+
latents = torch.randn(
|
91 |
+
(1, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
92 |
+
|
93 |
+
if inject_selfattn > 0:
|
94 |
+
latents_reference = latents.clone().detach()
|
95 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
96 |
+
n_styles = text_embeddings.shape[0]-1
|
97 |
+
print(n_styles, len(self.masks))
|
98 |
+
assert n_styles == len(self.masks)
|
99 |
+
|
100 |
+
with torch.autocast('cuda'):
|
101 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
102 |
+
|
103 |
+
# predict the noise residual
|
104 |
+
with torch.no_grad():
|
105 |
+
# tokens without any attributes
|
106 |
+
feat_inject_step = t > (1-inject_selfattn) * 1000
|
107 |
+
noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
|
108 |
+
# text_format_dict={})['sample']
|
109 |
+
)['sample']
|
110 |
+
# tokens without any style or footnote
|
111 |
+
self.register_fontsize_hooks(text_format_dict)
|
112 |
+
noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
|
113 |
+
# text_format_dict=text_format_dict)['sample']
|
114 |
+
)['sample']
|
115 |
+
self.remove_fontsize_hooks()
|
116 |
+
if inject_selfattn > 0 or inject_background > 0:
|
117 |
+
noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
|
118 |
+
# text_format_dict={})['sample']
|
119 |
+
)['sample']
|
120 |
+
self.register_selfattn_hooks(feat_inject_step)
|
121 |
+
noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
|
122 |
+
# text_format_dict={})['sample']
|
123 |
+
)['sample']
|
124 |
+
self.remove_selfattn_hooks()
|
125 |
+
noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
|
126 |
+
noise_pred_text = noise_pred_text_cur * self.masks[-1]
|
127 |
+
# tokens with attributes
|
128 |
+
for style_i, mask in enumerate(self.masks[:-1]):
|
129 |
+
if t > bg_aug_end:
|
130 |
+
rand_rgb = torch.rand([1, 3, 1, 1]).cuda()
|
131 |
+
black_background = torch.ones(
|
132 |
+
[1, 3, height, width]).cuda()*rand_rgb
|
133 |
+
black_latent = self.encode_imgs(
|
134 |
+
black_background)
|
135 |
+
noise = torch.randn_like(black_latent)
|
136 |
+
black_latent_noisy = self.scheduler.add_noise(
|
137 |
+
black_latent, noise, t)
|
138 |
+
masked_latent = (
|
139 |
+
mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy
|
140 |
+
noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[:1],
|
141 |
+
text_format_dict={})['sample']
|
142 |
+
else:
|
143 |
+
masked_latent = latents
|
144 |
+
self.register_replacement_hooks(feat_inject_step)
|
145 |
+
noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
|
146 |
+
# text_format_dict={})['sample']
|
147 |
+
)['sample']
|
148 |
+
self.remove_replacement_hooks()
|
149 |
+
noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
|
150 |
+
noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
|
151 |
+
|
152 |
+
# perform guidance
|
153 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
154 |
+
(noise_pred_text - noise_pred_uncond)
|
155 |
+
|
156 |
+
if inject_selfattn > 0:
|
157 |
+
noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
|
158 |
+
(noise_pred_text_refer - noise_pred_uncond_refer)
|
159 |
+
|
160 |
+
# compute the previous noisy sample x_t -> x_t-1
|
161 |
+
latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
|
162 |
+
torch.cat([latents, latents_reference]))[
|
163 |
+
'prev_sample']
|
164 |
+
latents, latents_reference = torch.chunk(
|
165 |
+
latents_reference, 2, dim=0)
|
166 |
+
|
167 |
+
else:
|
168 |
+
# compute the previous noisy sample x_t -> x_t-1
|
169 |
+
latents = self.scheduler.step(noise_pred, t, latents)[
|
170 |
+
'prev_sample']
|
171 |
+
|
172 |
+
# apply guidance
|
173 |
+
if use_guidance and t < text_format_dict['guidance_start_step']:
|
174 |
+
with torch.enable_grad():
|
175 |
+
if not latents.requires_grad:
|
176 |
+
latents.requires_grad = True
|
177 |
+
latents_0 = self.predict_x0(latents, noise_pred, t)
|
178 |
+
latents_inp = 1 / 0.18215 * latents_0
|
179 |
+
imgs = self.vae.decode(latents_inp).sample
|
180 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
181 |
+
# save_path = 'results/font_color/20230425/church_process/orange/'
|
182 |
+
# os.makedirs(save_path, exist_ok=True)
|
183 |
+
# torchvision.utils.save_image(
|
184 |
+
# imgs, os.path.join(save_path, 'step%d.png' % t))
|
185 |
+
# loss = (((imgs - text_format_dict['target_RGB'])*text_format_dict['color_obj_atten'][:, 0])**2).mean()*100
|
186 |
+
loss_total = 0.
|
187 |
+
for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
|
188 |
+
# loss = self.color_loss(
|
189 |
+
# imgs*attn_map[:, 0], rgb_val*attn_map[:, 0])*100
|
190 |
+
avg_rgb = (
|
191 |
+
imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
|
192 |
+
loss = self.color_loss(
|
193 |
+
avg_rgb, rgb_val[:, :, 0, 0])*100
|
194 |
+
# print(loss)
|
195 |
+
loss_total += loss
|
196 |
+
loss_total.backward()
|
197 |
+
latents = (
|
198 |
+
latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()
|
199 |
+
|
200 |
+
return latents
|
201 |
+
|
202 |
+
def predict_x0(self, x_t, eps_t, t):
|
203 |
+
alpha_t = self.scheduler.alphas_cumprod[t]
|
204 |
+
return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
|
205 |
+
|
206 |
+
def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
|
207 |
+
guidance_scale=7.5, latents=None):
|
208 |
+
|
209 |
+
if isinstance(prompts, str):
|
210 |
+
prompts = [prompts]
|
211 |
+
|
212 |
+
if isinstance(negative_prompts, str):
|
213 |
+
negative_prompts = [negative_prompts]
|
214 |
+
|
215 |
+
# Prompts -> text embeds
|
216 |
+
text_embeddings = self.get_text_embeds(
|
217 |
+
prompts, negative_prompts) # [2, 77, 768]
|
218 |
+
if latents is None:
|
219 |
+
latents = torch.randn(
|
220 |
+
(text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
221 |
+
|
222 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
223 |
+
self.remove_replacement_hooks()
|
224 |
+
|
225 |
+
with torch.autocast('cuda'):
|
226 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
227 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
228 |
+
latent_model_input = torch.cat([latents] * 2)
|
229 |
+
|
230 |
+
# predict the noise residual
|
231 |
+
with torch.no_grad():
|
232 |
+
noise_pred = self.unet(
|
233 |
+
latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
|
234 |
+
|
235 |
+
# perform guidance
|
236 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
237 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
238 |
+
(noise_pred_text - noise_pred_uncond)
|
239 |
+
|
240 |
+
# compute the previous noisy sample x_t -> x_t-1
|
241 |
+
latents = self.scheduler.step(noise_pred, t, latents)[
|
242 |
+
'prev_sample']
|
243 |
+
|
244 |
+
# Img latents -> imgs
|
245 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
246 |
+
|
247 |
+
# Img to Numpy
|
248 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
249 |
+
imgs = (imgs * 255).round().astype('uint8')
|
250 |
+
|
251 |
+
return imgs
|
252 |
+
|
253 |
+
def decode_latents(self, latents):
|
254 |
+
|
255 |
+
latents = 1 / 0.18215 * latents
|
256 |
+
|
257 |
+
with torch.no_grad():
|
258 |
+
imgs = self.vae.decode(latents).sample
|
259 |
+
|
260 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
261 |
+
|
262 |
+
return imgs
|
263 |
+
|
264 |
+
def encode_imgs(self, imgs):
|
265 |
+
# imgs: [B, 3, H, W]
|
266 |
+
|
267 |
+
imgs = 2 * imgs - 1
|
268 |
+
|
269 |
+
posterior = self.vae.encode(imgs).latent_dist
|
270 |
+
latents = posterior.sample() * 0.18215
|
271 |
+
|
272 |
+
return latents
|
273 |
+
|
274 |
+
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
|
275 |
+
guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, bg_aug_end=1000):
|
276 |
+
|
277 |
+
if isinstance(prompts, str):
|
278 |
+
prompts = [prompts]
|
279 |
+
|
280 |
+
if isinstance(negative_prompts, str):
|
281 |
+
negative_prompts = [negative_prompts]
|
282 |
+
|
283 |
+
# Prompts -> text embeds
|
284 |
+
text_embeds = self.get_text_embeds(
|
285 |
+
prompts, negative_prompts) # [2, 77, 768]
|
286 |
+
|
287 |
+
# else:
|
288 |
+
latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
|
289 |
+
num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
|
290 |
+
use_guidance=use_guidance, text_format_dict=text_format_dict,
|
291 |
+
inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end) # [1, 4, 64, 64]
|
292 |
+
# Img latents -> imgs
|
293 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
294 |
+
|
295 |
+
# Img to Numpy
|
296 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
297 |
+
imgs = (imgs * 255).round().astype('uint8')
|
298 |
+
|
299 |
+
return imgs
|
300 |
+
|
301 |
+
def reset_attention_maps(self):
|
302 |
+
r"""Function to reset attention maps.
|
303 |
+
We reset attention maps because we append them while getting hooks
|
304 |
+
to visualize attention maps for every step.
|
305 |
+
"""
|
306 |
+
for key in self.selfattn_maps:
|
307 |
+
self.selfattn_maps[key] = []
|
308 |
+
for key in self.crossattn_maps:
|
309 |
+
self.crossattn_maps[key] = []
|
310 |
+
|
311 |
+
def register_evaluation_hooks(self):
|
312 |
+
r"""Function for registering hooks during evaluation.
|
313 |
+
We mainly store activation maps averaged over queries.
|
314 |
+
"""
|
315 |
+
self.forward_hooks = []
|
316 |
+
|
317 |
+
def save_activations(activations, name, module, inp, out):
|
318 |
+
r"""
|
319 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
320 |
+
"""
|
321 |
+
# out[0] - final output of attention layer
|
322 |
+
# out[1] - attention probability matrix
|
323 |
+
if 'attn2' in name:
|
324 |
+
assert out[1].shape[-1] == 77
|
325 |
+
activations[name].append(out[1].detach().cpu())
|
326 |
+
else:
|
327 |
+
assert out[1].shape[-1] != 77
|
328 |
+
attention_dict = collections.defaultdict(list)
|
329 |
+
for name, module in self.unet.named_modules():
|
330 |
+
leaf_name = name.split('.')[-1]
|
331 |
+
if 'attn' in leaf_name:
|
332 |
+
# Register hook to obtain outputs at every attention layer.
|
333 |
+
self.forward_hooks.append(module.register_forward_hook(
|
334 |
+
partial(save_activations, attention_dict, name)
|
335 |
+
))
|
336 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
337 |
+
self.attention_maps = attention_dict
|
338 |
+
|
339 |
+
def register_selfattn_hooks(self, feat_inject_step=False):
|
340 |
+
r"""Function for registering hooks during evaluation.
|
341 |
+
We mainly store activation maps averaged over queries.
|
342 |
+
"""
|
343 |
+
self.selfattn_forward_hooks = []
|
344 |
+
|
345 |
+
def save_activations(activations, name, module, inp, out):
|
346 |
+
r"""
|
347 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
348 |
+
"""
|
349 |
+
# out[0] - final output of attention layer
|
350 |
+
# out[1] - attention probability matrix
|
351 |
+
if 'attn2' in name:
|
352 |
+
assert out[1][1].shape[-1] == 77
|
353 |
+
# cross attention injection
|
354 |
+
# activations[name] = out[1][1].detach()
|
355 |
+
else:
|
356 |
+
assert out[1][1].shape[-1] != 77
|
357 |
+
activations[name] = out[1][1].detach()
|
358 |
+
|
359 |
+
def save_resnet_activations(activations, name, module, inp, out):
|
360 |
+
r"""
|
361 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
362 |
+
"""
|
363 |
+
# out[0] - final output of residual layer
|
364 |
+
# out[1] - residual hidden feature
|
365 |
+
# import ipdb
|
366 |
+
# ipdb.set_trace()
|
367 |
+
assert out[1].shape[-1] == 16
|
368 |
+
activations[name] = out[1].detach()
|
369 |
+
attention_dict = collections.defaultdict(list)
|
370 |
+
for name, module in self.unet.named_modules():
|
371 |
+
leaf_name = name.split('.')[-1]
|
372 |
+
if 'attn' in leaf_name and feat_inject_step:
|
373 |
+
# Register hook to obtain outputs at every attention layer.
|
374 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
375 |
+
partial(save_activations, attention_dict, name)
|
376 |
+
))
|
377 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
378 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
379 |
+
partial(save_resnet_activations, attention_dict, name)
|
380 |
+
))
|
381 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
382 |
+
self.self_attention_maps_cur = attention_dict
|
383 |
+
|
384 |
+
def register_replacement_hooks(self, feat_inject_step=False):
|
385 |
+
r"""Function for registering hooks to replace self attention.
|
386 |
+
"""
|
387 |
+
self.forward_replacement_hooks = []
|
388 |
+
|
389 |
+
def replace_activations(name, module, args):
|
390 |
+
r"""
|
391 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
392 |
+
"""
|
393 |
+
if 'attn1' in name:
|
394 |
+
modified_args = (args[0], self.self_attention_maps_cur[name])
|
395 |
+
return modified_args
|
396 |
+
# cross attention injection
|
397 |
+
# elif 'attn2' in name:
|
398 |
+
# modified_map = {
|
399 |
+
# 'reference': self.self_attention_maps_cur[name],
|
400 |
+
# 'inject_pos': self.inject_pos,
|
401 |
+
# }
|
402 |
+
# modified_args = (args[0], modified_map)
|
403 |
+
# return modified_args
|
404 |
+
|
405 |
+
def replace_resnet_activations(name, module, args):
|
406 |
+
r"""
|
407 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
408 |
+
"""
|
409 |
+
modified_args = (args[0], args[1],
|
410 |
+
self.self_attention_maps_cur[name])
|
411 |
+
return modified_args
|
412 |
+
for name, module in self.unet.named_modules():
|
413 |
+
leaf_name = name.split('.')[-1]
|
414 |
+
if 'attn' in leaf_name and feat_inject_step:
|
415 |
+
# Register hook to obtain outputs at every attention layer.
|
416 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
417 |
+
partial(replace_activations, name)
|
418 |
+
))
|
419 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
420 |
+
# Register hook to obtain outputs at every attention layer.
|
421 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
422 |
+
partial(replace_resnet_activations, name)
|
423 |
+
))
|
424 |
+
|
425 |
+
def register_tokenmap_hooks(self):
|
426 |
+
r"""Function for registering hooks during evaluation.
|
427 |
+
We mainly store activation maps averaged over queries.
|
428 |
+
"""
|
429 |
+
self.forward_hooks = []
|
430 |
+
|
431 |
+
def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
|
432 |
+
r"""
|
433 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
434 |
+
"""
|
435 |
+
# out[0] - final output of attention layer
|
436 |
+
# out[1] - attention probability matrices
|
437 |
+
if name in n_maps:
|
438 |
+
n_maps[name] += 1
|
439 |
+
else:
|
440 |
+
n_maps[name] = 1
|
441 |
+
if 'attn2' in name:
|
442 |
+
assert out[1][0].shape[-1] == 77
|
443 |
+
if name in CrossAttentionLayers and n_maps[name] > 10:
|
444 |
+
if name in crossattn_maps:
|
445 |
+
crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
446 |
+
else:
|
447 |
+
crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
448 |
+
else:
|
449 |
+
assert out[1][0].shape[-1] != 77
|
450 |
+
if name in SelfAttentionLayers and n_maps[name] > 10:
|
451 |
+
if name in crossattn_maps:
|
452 |
+
selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
453 |
+
else:
|
454 |
+
selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
455 |
+
|
456 |
+
selfattn_maps = collections.defaultdict(list)
|
457 |
+
crossattn_maps = collections.defaultdict(list)
|
458 |
+
n_maps = collections.defaultdict(list)
|
459 |
+
|
460 |
+
for name, module in self.unet.named_modules():
|
461 |
+
leaf_name = name.split('.')[-1]
|
462 |
+
if 'attn' in leaf_name:
|
463 |
+
# Register hook to obtain outputs at every attention layer.
|
464 |
+
self.forward_hooks.append(module.register_forward_hook(
|
465 |
+
partial(save_activations, selfattn_maps,
|
466 |
+
crossattn_maps, n_maps, name)
|
467 |
+
))
|
468 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
469 |
+
self.selfattn_maps = selfattn_maps
|
470 |
+
self.crossattn_maps = crossattn_maps
|
471 |
+
self.n_maps = n_maps
|
472 |
+
|
473 |
+
def remove_tokenmap_hooks(self):
|
474 |
+
for hook in self.forward_hooks:
|
475 |
+
hook.remove()
|
476 |
+
self.selfattn_maps = None
|
477 |
+
self.crossattn_maps = None
|
478 |
+
self.n_maps = None
|
479 |
+
|
480 |
+
def remove_evaluation_hooks(self):
|
481 |
+
for hook in self.forward_hooks:
|
482 |
+
hook.remove()
|
483 |
+
self.attention_maps = None
|
484 |
+
|
485 |
+
def remove_replacement_hooks(self):
|
486 |
+
for hook in self.forward_replacement_hooks:
|
487 |
+
hook.remove()
|
488 |
+
|
489 |
+
def remove_selfattn_hooks(self):
|
490 |
+
for hook in self.selfattn_forward_hooks:
|
491 |
+
hook.remove()
|
492 |
+
|
493 |
+
def register_fontsize_hooks(self, text_format_dict={}):
|
494 |
+
r"""Function for registering hooks to replace self attention.
|
495 |
+
"""
|
496 |
+
self.forward_fontsize_hooks = []
|
497 |
+
|
498 |
+
def adjust_attn_weights(name, module, args):
|
499 |
+
r"""
|
500 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
501 |
+
"""
|
502 |
+
if 'attn2' in name:
|
503 |
+
modified_args = (args[0], None, attn_weights)
|
504 |
+
return modified_args
|
505 |
+
|
506 |
+
if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
|
507 |
+
attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
|
508 |
+
else:
|
509 |
+
attn_weights = None
|
510 |
+
|
511 |
+
for name, module in self.unet.named_modules():
|
512 |
+
leaf_name = name.split('.')[-1]
|
513 |
+
if 'attn' in leaf_name and attn_weights is not None:
|
514 |
+
# Register hook to obtain outputs at every attention layer.
|
515 |
+
self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
|
516 |
+
partial(adjust_attn_weights, name)
|
517 |
+
))
|
518 |
+
|
519 |
+
def remove_fontsize_hooks(self):
|
520 |
+
for hook in self.forward_fontsize_hooks:
|
521 |
+
hook.remove()
|
models/region_diffusion_xl.py
ADDED
@@ -0,0 +1,1143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from diffusers.pipelines.stable_diffusion.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
8 |
+
|
9 |
+
from diffusers.image_processor import VaeImageProcessor
|
10 |
+
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
11 |
+
# from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
12 |
+
from diffusers.models import AutoencoderKL
|
13 |
+
|
14 |
+
from diffusers.models.attention_processor import (
|
15 |
+
AttnProcessor2_0,
|
16 |
+
LoRAAttnProcessor2_0,
|
17 |
+
LoRAXFormersAttnProcessor,
|
18 |
+
XFormersAttnProcessor,
|
19 |
+
)
|
20 |
+
from diffusers.schedulers import EulerDiscreteScheduler
|
21 |
+
from diffusers.utils import (
|
22 |
+
is_accelerate_available,
|
23 |
+
is_accelerate_version,
|
24 |
+
logging,
|
25 |
+
randn_tensor,
|
26 |
+
replace_example_docstring,
|
27 |
+
)
|
28 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
29 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
30 |
+
|
31 |
+
### cutomized modules
|
32 |
+
import collections
|
33 |
+
from functools import partial
|
34 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
35 |
+
|
36 |
+
from models.unet_2d_condition import UNet2DConditionModel
|
37 |
+
from utils.attention_utils import CrossAttentionLayers_XL
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
|
42 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
43 |
+
"""
|
44 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
45 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
46 |
+
"""
|
47 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
48 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
49 |
+
# rescale the results from guidance (fixes overexposure)
|
50 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
51 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
52 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
53 |
+
return noise_cfg
|
54 |
+
|
55 |
+
|
56 |
+
class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
57 |
+
r"""
|
58 |
+
Pipeline for text-to-image generation using Stable Diffusion.
|
59 |
+
|
60 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
61 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
62 |
+
|
63 |
+
In addition the pipeline inherits the following loading methods:
|
64 |
+
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
65 |
+
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
66 |
+
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
67 |
+
|
68 |
+
as well as the following saving methods:
|
69 |
+
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
70 |
+
|
71 |
+
Args:
|
72 |
+
vae ([`AutoencoderKL`]):
|
73 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
74 |
+
text_encoder ([`CLIPTextModel`]):
|
75 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
76 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
77 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
78 |
+
tokenizer (`CLIPTokenizer`):
|
79 |
+
Tokenizer of class
|
80 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
81 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
82 |
+
scheduler ([`SchedulerMixin`]):
|
83 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
84 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
load_path: str = "stabilityai/stable-diffusion-xl-base-1.0",
|
90 |
+
device: str = "cuda",
|
91 |
+
force_zeros_for_empty_prompt: bool = True,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
# self.register_modules(
|
96 |
+
# vae=vae,
|
97 |
+
# text_encoder=text_encoder,
|
98 |
+
# text_encoder_2=text_encoder_2,
|
99 |
+
# tokenizer=tokenizer,
|
100 |
+
# tokenizer_2=tokenizer_2,
|
101 |
+
# unet=unet,
|
102 |
+
# scheduler=scheduler,
|
103 |
+
# )
|
104 |
+
|
105 |
+
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
106 |
+
self.vae = AutoencoderKL.from_pretrained(load_path, subfolder="vae", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
107 |
+
|
108 |
+
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
109 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer')
|
110 |
+
self.tokenizer_2 = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer_2')
|
111 |
+
self.text_encoder = CLIPTextModel.from_pretrained(load_path, subfolder='text_encoder', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
112 |
+
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(load_path, subfolder='text_encoder_2', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
113 |
+
|
114 |
+
# 3. The UNet model for generating the latents.
|
115 |
+
self.unet = UNet2DConditionModel.from_pretrained(load_path, subfolder="unet", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
116 |
+
|
117 |
+
# 4. Scheduler.
|
118 |
+
self.scheduler = EulerDiscreteScheduler.from_pretrained(load_path, subfolder="scheduler")
|
119 |
+
|
120 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
121 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
122 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
123 |
+
self.default_sample_size = self.unet.config.sample_size
|
124 |
+
|
125 |
+
self.watermark = StableDiffusionXLWatermarker()
|
126 |
+
|
127 |
+
self.device_type = device
|
128 |
+
|
129 |
+
self.masks = []
|
130 |
+
self.attention_maps = None
|
131 |
+
self.selfattn_maps = None
|
132 |
+
self.crossattn_maps = None
|
133 |
+
self.color_loss = torch.nn.functional.mse_loss
|
134 |
+
self.forward_hooks = []
|
135 |
+
self.forward_replacement_hooks = []
|
136 |
+
|
137 |
+
# Overwriting the method from diffusers.pipelines.diffusion_pipeline.DiffusionPipeline
|
138 |
+
@property
|
139 |
+
def device(self) -> torch.device:
|
140 |
+
r"""
|
141 |
+
Returns:
|
142 |
+
`torch.device`: The torch device on which the pipeline is located.
|
143 |
+
"""
|
144 |
+
|
145 |
+
return torch.device(self.device_type)
|
146 |
+
|
147 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
148 |
+
def enable_vae_slicing(self):
|
149 |
+
r"""
|
150 |
+
Enable sliced VAE decoding.
|
151 |
+
|
152 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
153 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
154 |
+
"""
|
155 |
+
self.vae.enable_slicing()
|
156 |
+
|
157 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
158 |
+
def disable_vae_slicing(self):
|
159 |
+
r"""
|
160 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
161 |
+
computing decoding in one step.
|
162 |
+
"""
|
163 |
+
self.vae.disable_slicing()
|
164 |
+
|
165 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
166 |
+
def enable_vae_tiling(self):
|
167 |
+
r"""
|
168 |
+
Enable tiled VAE decoding.
|
169 |
+
|
170 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
171 |
+
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
172 |
+
"""
|
173 |
+
self.vae.enable_tiling()
|
174 |
+
|
175 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
176 |
+
def disable_vae_tiling(self):
|
177 |
+
r"""
|
178 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
179 |
+
computing decoding in one step.
|
180 |
+
"""
|
181 |
+
self.vae.disable_tiling()
|
182 |
+
|
183 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
184 |
+
r"""
|
185 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
186 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
187 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
188 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
189 |
+
`enable_model_cpu_offload`, but performance is lower.
|
190 |
+
"""
|
191 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
192 |
+
from accelerate import cpu_offload
|
193 |
+
else:
|
194 |
+
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
195 |
+
|
196 |
+
device = torch.device(f"cuda:{gpu_id}")
|
197 |
+
|
198 |
+
if self.device.type != "cpu":
|
199 |
+
self.to("cpu", silence_dtype_warnings=True)
|
200 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
201 |
+
|
202 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
|
203 |
+
cpu_offload(cpu_offloaded_model, device)
|
204 |
+
|
205 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
206 |
+
r"""
|
207 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
208 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
209 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
210 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
211 |
+
"""
|
212 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
213 |
+
from accelerate import cpu_offload_with_hook
|
214 |
+
else:
|
215 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
216 |
+
|
217 |
+
device = torch.device(f"cuda:{gpu_id}")
|
218 |
+
|
219 |
+
if self.device.type != "cpu":
|
220 |
+
self.to("cpu", silence_dtype_warnings=True)
|
221 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
222 |
+
|
223 |
+
model_sequence = (
|
224 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
225 |
+
)
|
226 |
+
model_sequence.extend([self.unet, self.vae])
|
227 |
+
|
228 |
+
hook = None
|
229 |
+
for cpu_offloaded_model in model_sequence:
|
230 |
+
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
231 |
+
|
232 |
+
# We'll offload the last model manually.
|
233 |
+
self.final_offload_hook = hook
|
234 |
+
|
235 |
+
@property
|
236 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
237 |
+
def _execution_device(self):
|
238 |
+
r"""
|
239 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
240 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
241 |
+
hooks.
|
242 |
+
"""
|
243 |
+
if not hasattr(self.unet, "_hf_hook"):
|
244 |
+
return self.device
|
245 |
+
for module in self.unet.modules():
|
246 |
+
if (
|
247 |
+
hasattr(module, "_hf_hook")
|
248 |
+
and hasattr(module._hf_hook, "execution_device")
|
249 |
+
and module._hf_hook.execution_device is not None
|
250 |
+
):
|
251 |
+
return torch.device(module._hf_hook.execution_device)
|
252 |
+
return self.device
|
253 |
+
|
254 |
+
def encode_prompt(
|
255 |
+
self,
|
256 |
+
prompt,
|
257 |
+
device: Optional[torch.device] = None,
|
258 |
+
num_images_per_prompt: int = 1,
|
259 |
+
do_classifier_free_guidance: bool = True,
|
260 |
+
negative_prompt=None,
|
261 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
262 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
263 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
264 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
265 |
+
lora_scale: Optional[float] = None,
|
266 |
+
):
|
267 |
+
r"""
|
268 |
+
Encodes the prompt into text encoder hidden states.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
prompt (`str` or `List[str]`, *optional*):
|
272 |
+
prompt to be encoded
|
273 |
+
device: (`torch.device`):
|
274 |
+
torch device
|
275 |
+
num_images_per_prompt (`int`):
|
276 |
+
number of images that should be generated per prompt
|
277 |
+
do_classifier_free_guidance (`bool`):
|
278 |
+
whether to use classifier free guidance or not
|
279 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
280 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
281 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
282 |
+
less than `1`).
|
283 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
284 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
285 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
286 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
287 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
288 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
289 |
+
argument.
|
290 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
291 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
292 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
293 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
294 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
295 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
296 |
+
input argument.
|
297 |
+
lora_scale (`float`, *optional*):
|
298 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
299 |
+
"""
|
300 |
+
device = device or self._execution_device
|
301 |
+
|
302 |
+
# set lora scale so that monkey patched LoRA
|
303 |
+
# function of text encoder can correctly access it
|
304 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
305 |
+
self._lora_scale = lora_scale
|
306 |
+
|
307 |
+
if prompt is not None and isinstance(prompt, str):
|
308 |
+
batch_size = 1
|
309 |
+
elif prompt is not None and isinstance(prompt, list):
|
310 |
+
batch_size = len(prompt)
|
311 |
+
batch_size_neg = len(negative_prompt)
|
312 |
+
else:
|
313 |
+
batch_size = prompt_embeds.shape[0]
|
314 |
+
|
315 |
+
# Define tokenizers and text encoders
|
316 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
317 |
+
text_encoders = (
|
318 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
319 |
+
)
|
320 |
+
|
321 |
+
if prompt_embeds is None:
|
322 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
323 |
+
prompt_embeds_list = []
|
324 |
+
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
325 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
326 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
327 |
+
|
328 |
+
text_inputs = tokenizer(
|
329 |
+
prompt,
|
330 |
+
padding="max_length",
|
331 |
+
max_length=tokenizer.model_max_length,
|
332 |
+
truncation=True,
|
333 |
+
return_tensors="pt",
|
334 |
+
)
|
335 |
+
text_input_ids = text_inputs.input_ids
|
336 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
337 |
+
|
338 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
339 |
+
text_input_ids, untruncated_ids
|
340 |
+
):
|
341 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
342 |
+
logger.warning(
|
343 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
344 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
345 |
+
)
|
346 |
+
|
347 |
+
prompt_embeds = text_encoder(
|
348 |
+
text_input_ids.to(device),
|
349 |
+
output_hidden_states=True,
|
350 |
+
)
|
351 |
+
|
352 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
353 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
354 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
355 |
+
|
356 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
357 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
358 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
359 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
360 |
+
|
361 |
+
prompt_embeds_list.append(prompt_embeds)
|
362 |
+
|
363 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
364 |
+
|
365 |
+
# get unconditional embeddings for classifier free guidance
|
366 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
367 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
368 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
369 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
370 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
371 |
+
negative_prompt = negative_prompt or ""
|
372 |
+
uncond_tokens: List[str]
|
373 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
374 |
+
raise TypeError(
|
375 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
376 |
+
f" {type(prompt)}."
|
377 |
+
)
|
378 |
+
elif isinstance(negative_prompt, str):
|
379 |
+
uncond_tokens = [negative_prompt]
|
380 |
+
# elif batch_size != len(negative_prompt):
|
381 |
+
# raise ValueError(
|
382 |
+
# f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
383 |
+
# f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
384 |
+
# " the batch size of `prompt`."
|
385 |
+
# )
|
386 |
+
else:
|
387 |
+
uncond_tokens = negative_prompt
|
388 |
+
|
389 |
+
negative_prompt_embeds_list = []
|
390 |
+
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
391 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
392 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
393 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
|
394 |
+
|
395 |
+
max_length = prompt_embeds.shape[1]
|
396 |
+
uncond_input = tokenizer(
|
397 |
+
uncond_tokens,
|
398 |
+
padding="max_length",
|
399 |
+
max_length=max_length,
|
400 |
+
truncation=True,
|
401 |
+
return_tensors="pt",
|
402 |
+
)
|
403 |
+
|
404 |
+
negative_prompt_embeds = text_encoder(
|
405 |
+
uncond_input.input_ids.to(device),
|
406 |
+
output_hidden_states=True,
|
407 |
+
)
|
408 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
409 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
410 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
411 |
+
|
412 |
+
if do_classifier_free_guidance:
|
413 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
414 |
+
seq_len = negative_prompt_embeds.shape[1]
|
415 |
+
|
416 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
417 |
+
|
418 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
419 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
420 |
+
batch_size_neg * num_images_per_prompt, seq_len, -1
|
421 |
+
)
|
422 |
+
|
423 |
+
# For classifier free guidance, we need to do two forward passes.
|
424 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
425 |
+
# to avoid doing two forward passes
|
426 |
+
|
427 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
428 |
+
|
429 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
430 |
+
|
431 |
+
bs_embed = pooled_prompt_embeds.shape[0]
|
432 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
433 |
+
bs_embed * num_images_per_prompt, -1
|
434 |
+
)
|
435 |
+
bs_embed = negative_pooled_prompt_embeds.shape[0]
|
436 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
437 |
+
bs_embed * num_images_per_prompt, -1
|
438 |
+
)
|
439 |
+
|
440 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
441 |
+
|
442 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
443 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
444 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
445 |
+
# eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
446 |
+
# eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
|
447 |
+
# and should be between [0, 1]
|
448 |
+
|
449 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
450 |
+
extra_step_kwargs = {}
|
451 |
+
if accepts_eta:
|
452 |
+
extra_step_kwargs["eta"] = eta
|
453 |
+
|
454 |
+
# check if the scheduler accepts generator
|
455 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
456 |
+
if accepts_generator:
|
457 |
+
extra_step_kwargs["generator"] = generator
|
458 |
+
return extra_step_kwargs
|
459 |
+
|
460 |
+
def check_inputs(
|
461 |
+
self,
|
462 |
+
prompt,
|
463 |
+
height,
|
464 |
+
width,
|
465 |
+
callback_steps,
|
466 |
+
negative_prompt=None,
|
467 |
+
prompt_embeds=None,
|
468 |
+
negative_prompt_embeds=None,
|
469 |
+
pooled_prompt_embeds=None,
|
470 |
+
negative_pooled_prompt_embeds=None,
|
471 |
+
):
|
472 |
+
if height % 8 != 0 or width % 8 != 0:
|
473 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
474 |
+
|
475 |
+
if (callback_steps is None) or (
|
476 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
477 |
+
):
|
478 |
+
raise ValueError(
|
479 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
480 |
+
f" {type(callback_steps)}."
|
481 |
+
)
|
482 |
+
|
483 |
+
if prompt is not None and prompt_embeds is not None:
|
484 |
+
raise ValueError(
|
485 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
486 |
+
" only forward one of the two."
|
487 |
+
)
|
488 |
+
elif prompt is None and prompt_embeds is None:
|
489 |
+
raise ValueError(
|
490 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
491 |
+
)
|
492 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
493 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
494 |
+
|
495 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
496 |
+
raise ValueError(
|
497 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
498 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
499 |
+
)
|
500 |
+
|
501 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
502 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
503 |
+
raise ValueError(
|
504 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
505 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
506 |
+
f" {negative_prompt_embeds.shape}."
|
507 |
+
)
|
508 |
+
|
509 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
510 |
+
raise ValueError(
|
511 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
512 |
+
)
|
513 |
+
|
514 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
515 |
+
raise ValueError(
|
516 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
517 |
+
)
|
518 |
+
|
519 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
520 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
521 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
522 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
523 |
+
raise ValueError(
|
524 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
525 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
526 |
+
)
|
527 |
+
|
528 |
+
if latents is None:
|
529 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
530 |
+
else:
|
531 |
+
latents = latents.to(device)
|
532 |
+
|
533 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
534 |
+
latents = latents * self.scheduler.init_noise_sigma
|
535 |
+
return latents
|
536 |
+
|
537 |
+
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
538 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
539 |
+
|
540 |
+
passed_add_embed_dim = (
|
541 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
542 |
+
)
|
543 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
544 |
+
|
545 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
546 |
+
raise ValueError(
|
547 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
548 |
+
)
|
549 |
+
|
550 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
551 |
+
return add_time_ids
|
552 |
+
|
553 |
+
@torch.no_grad()
|
554 |
+
def sample(
|
555 |
+
self,
|
556 |
+
prompt: Union[str, List[str]] = None,
|
557 |
+
height: Optional[int] = None,
|
558 |
+
width: Optional[int] = None,
|
559 |
+
num_inference_steps: int = 50,
|
560 |
+
guidance_scale: float = 5.0,
|
561 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
562 |
+
num_images_per_prompt: Optional[int] = 1,
|
563 |
+
eta: float = 0.0,
|
564 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
565 |
+
latents: Optional[torch.FloatTensor] = None,
|
566 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
567 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
568 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
569 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
570 |
+
output_type: Optional[str] = "pil",
|
571 |
+
return_dict: bool = True,
|
572 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
573 |
+
callback_steps: int = 1,
|
574 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
575 |
+
guidance_rescale: float = 0.0,
|
576 |
+
original_size: Optional[Tuple[int, int]] = None,
|
577 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
578 |
+
target_size: Optional[Tuple[int, int]] = None,
|
579 |
+
# Rich-Text args
|
580 |
+
use_guidance: bool = False,
|
581 |
+
inject_selfattn: float = 0.0,
|
582 |
+
inject_background: float = 0.0,
|
583 |
+
text_format_dict: Optional[dict] = None,
|
584 |
+
run_rich_text: bool = False,
|
585 |
+
):
|
586 |
+
r"""
|
587 |
+
Function invoked when calling the pipeline for generation.
|
588 |
+
|
589 |
+
Args:
|
590 |
+
prompt (`str` or `List[str]`, *optional*):
|
591 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
592 |
+
instead.
|
593 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
594 |
+
The height in pixels of the generated image.
|
595 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
596 |
+
The width in pixels of the generated image.
|
597 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
598 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
599 |
+
expense of slower inference.
|
600 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
601 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
602 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
603 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
604 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
605 |
+
usually at the expense of lower image quality.
|
606 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
607 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
608 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
609 |
+
less than `1`).
|
610 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
611 |
+
The number of images to generate per prompt.
|
612 |
+
eta (`float`, *optional*, defaults to 0.0):
|
613 |
+
Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
614 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
615 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
616 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
617 |
+
to make generation deterministic.
|
618 |
+
latents (`torch.FloatTensor`, *optional*):
|
619 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
620 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
621 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
622 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
623 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
624 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
625 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
626 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
627 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
628 |
+
argument.
|
629 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
630 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
631 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
632 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
633 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
634 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
635 |
+
input argument.
|
636 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
637 |
+
The output format of the generate image. Choose between
|
638 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
639 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
640 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
|
641 |
+
plain tuple.
|
642 |
+
callback (`Callable`, *optional*):
|
643 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
644 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
645 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
646 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
647 |
+
called at every step.
|
648 |
+
cross_attention_kwargs (`dict`, *optional*):
|
649 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
650 |
+
`self.processor` in
|
651 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
652 |
+
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
653 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
654 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `Ο` in equation 16. of
|
655 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
656 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
657 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
658 |
+
TODO
|
659 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
660 |
+
TODO
|
661 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
662 |
+
TODO
|
663 |
+
|
664 |
+
Examples:
|
665 |
+
|
666 |
+
Returns:
|
667 |
+
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
|
668 |
+
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
669 |
+
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
670 |
+
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
671 |
+
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
672 |
+
"""
|
673 |
+
# 0. Default height and width to unet
|
674 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
675 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
676 |
+
|
677 |
+
original_size = original_size or (height, width)
|
678 |
+
target_size = target_size or (height, width)
|
679 |
+
|
680 |
+
# 1. Check inputs. Raise error if not correct
|
681 |
+
self.check_inputs(
|
682 |
+
prompt,
|
683 |
+
height,
|
684 |
+
width,
|
685 |
+
callback_steps,
|
686 |
+
negative_prompt,
|
687 |
+
prompt_embeds,
|
688 |
+
negative_prompt_embeds,
|
689 |
+
pooled_prompt_embeds,
|
690 |
+
negative_pooled_prompt_embeds,
|
691 |
+
)
|
692 |
+
|
693 |
+
# 2. Define call parameters
|
694 |
+
if prompt is not None and isinstance(prompt, str):
|
695 |
+
batch_size = 1
|
696 |
+
elif prompt is not None and isinstance(prompt, list):
|
697 |
+
# TODO: support batched prompts
|
698 |
+
batch_size = 1
|
699 |
+
# batch_size = len(prompt)
|
700 |
+
else:
|
701 |
+
batch_size = prompt_embeds.shape[0]
|
702 |
+
|
703 |
+
device = self._execution_device
|
704 |
+
|
705 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
706 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
707 |
+
# corresponds to doing no classifier free guidance.
|
708 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
709 |
+
|
710 |
+
# 3. Encode input prompt
|
711 |
+
text_encoder_lora_scale = (
|
712 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
713 |
+
)
|
714 |
+
(
|
715 |
+
prompt_embeds,
|
716 |
+
negative_prompt_embeds,
|
717 |
+
pooled_prompt_embeds,
|
718 |
+
negative_pooled_prompt_embeds,
|
719 |
+
) = self.encode_prompt(
|
720 |
+
prompt,
|
721 |
+
device,
|
722 |
+
num_images_per_prompt,
|
723 |
+
do_classifier_free_guidance,
|
724 |
+
negative_prompt,
|
725 |
+
prompt_embeds=prompt_embeds,
|
726 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
727 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
728 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
729 |
+
lora_scale=text_encoder_lora_scale,
|
730 |
+
)
|
731 |
+
|
732 |
+
# 4. Prepare timesteps
|
733 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
734 |
+
|
735 |
+
timesteps = self.scheduler.timesteps
|
736 |
+
|
737 |
+
# 5. Prepare latent variables
|
738 |
+
num_channels_latents = self.unet.config.in_channels
|
739 |
+
latents = self.prepare_latents(
|
740 |
+
batch_size * num_images_per_prompt,
|
741 |
+
num_channels_latents,
|
742 |
+
height,
|
743 |
+
width,
|
744 |
+
prompt_embeds.dtype,
|
745 |
+
device,
|
746 |
+
generator,
|
747 |
+
latents,
|
748 |
+
)
|
749 |
+
|
750 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
751 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
752 |
+
|
753 |
+
# 7. Prepare added time ids & embeddings
|
754 |
+
add_text_embeds = pooled_prompt_embeds
|
755 |
+
add_time_ids = self._get_add_time_ids(
|
756 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
757 |
+
)
|
758 |
+
|
759 |
+
if do_classifier_free_guidance:
|
760 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
761 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
762 |
+
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
763 |
+
|
764 |
+
prompt_embeds = prompt_embeds.to(device)
|
765 |
+
add_text_embeds = add_text_embeds.to(device)
|
766 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
767 |
+
|
768 |
+
# 8. Denoising loop
|
769 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
770 |
+
if run_rich_text:
|
771 |
+
if inject_selfattn > 0 or inject_background > 0:
|
772 |
+
latents_reference = latents.clone().detach()
|
773 |
+
n_styles = prompt_embeds.shape[0]-1
|
774 |
+
self.masks = [mask.to(dtype=prompt_embeds.dtype) for mask in self.masks]
|
775 |
+
print(n_styles, len(self.masks))
|
776 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
777 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
778 |
+
# predict the noise residual
|
779 |
+
with torch.no_grad():
|
780 |
+
feat_inject_step = t > (1-inject_selfattn) * 1000
|
781 |
+
background_inject_step = i < inject_background * len(self.scheduler.timesteps)
|
782 |
+
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
783 |
+
# import ipdb;ipdb.set_trace()
|
784 |
+
# unconditional prediction
|
785 |
+
noise_pred_uncond_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[:1],
|
786 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
787 |
+
added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
|
788 |
+
)['sample']
|
789 |
+
# tokens without any style or footnote
|
790 |
+
self.register_fontsize_hooks(text_format_dict)
|
791 |
+
noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
|
792 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
793 |
+
added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
|
794 |
+
)['sample']
|
795 |
+
self.remove_fontsize_hooks()
|
796 |
+
if inject_selfattn > 0 or inject_background > 0:
|
797 |
+
latent_reference_model_input = self.scheduler.scale_model_input(latents_reference, t)
|
798 |
+
noise_pred_uncond_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[:1],
|
799 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
800 |
+
added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
|
801 |
+
)['sample']
|
802 |
+
self.register_selfattn_hooks(feat_inject_step)
|
803 |
+
noise_pred_text_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
|
804 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
805 |
+
added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
|
806 |
+
)['sample']
|
807 |
+
self.remove_selfattn_hooks()
|
808 |
+
noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
|
809 |
+
noise_pred_text = noise_pred_text_cur * self.masks[-1]
|
810 |
+
# tokens with style or footnote
|
811 |
+
for style_i, mask in enumerate(self.masks[:-1]):
|
812 |
+
self.register_replacement_hooks(feat_inject_step)
|
813 |
+
noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[style_i+1:style_i+2],
|
814 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
815 |
+
added_cond_kwargs={"text_embeds": add_text_embeds[style_i+1:style_i+2], "time_ids": add_time_ids[:1]}
|
816 |
+
)['sample']
|
817 |
+
self.remove_replacement_hooks()
|
818 |
+
noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
|
819 |
+
noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
|
820 |
+
|
821 |
+
# perform guidance
|
822 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
823 |
+
(noise_pred_text - noise_pred_uncond)
|
824 |
+
|
825 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
826 |
+
# TODO: Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
827 |
+
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
828 |
+
raise NotImplementedError
|
829 |
+
|
830 |
+
if inject_selfattn > 0 or background_inject_step > 0:
|
831 |
+
noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
|
832 |
+
(noise_pred_text_refer - noise_pred_uncond_refer)
|
833 |
+
|
834 |
+
# compute the previous noisy sample x_t -> x_t-1
|
835 |
+
latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
|
836 |
+
torch.cat([latents, latents_reference]))[
|
837 |
+
'prev_sample']
|
838 |
+
latents, latents_reference = torch.chunk(
|
839 |
+
latents_reference, 2, dim=0)
|
840 |
+
|
841 |
+
else:
|
842 |
+
# compute the previous noisy sample x_t -> x_t-1
|
843 |
+
latents = self.scheduler.step(noise_pred, t, latents)[
|
844 |
+
'prev_sample']
|
845 |
+
|
846 |
+
# apply guidance
|
847 |
+
if use_guidance and t < text_format_dict['guidance_start_step']:
|
848 |
+
with torch.enable_grad():
|
849 |
+
self.unet.to(device='cpu')
|
850 |
+
torch.cuda.empty_cache()
|
851 |
+
if not latents.requires_grad:
|
852 |
+
latents.requires_grad = True
|
853 |
+
# import ipdb;ipdb.set_trace()
|
854 |
+
# latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
|
855 |
+
latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=torch.bfloat16)
|
856 |
+
latents_inp = latents_0 / self.vae.config.scaling_factor
|
857 |
+
imgs = self.vae.to(dtype=latents_inp.dtype).decode(latents_inp).sample
|
858 |
+
# imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
|
859 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
860 |
+
loss_total = 0.
|
861 |
+
for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
|
862 |
+
avg_rgb = (
|
863 |
+
imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
|
864 |
+
loss = self.color_loss(
|
865 |
+
avg_rgb, rgb_val[:, :, 0, 0])*100
|
866 |
+
loss_total += loss
|
867 |
+
loss_total.backward()
|
868 |
+
latents = (
|
869 |
+
latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
|
870 |
+
self.unet.to(device=latents.device)
|
871 |
+
|
872 |
+
# apply background injection
|
873 |
+
if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
|
874 |
+
latents = latents_reference * self.masks[-1] + latents * \
|
875 |
+
(1-self.masks[-1])
|
876 |
+
|
877 |
+
# call the callback, if provided
|
878 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
879 |
+
progress_bar.update()
|
880 |
+
if callback is not None and i % callback_steps == 0:
|
881 |
+
callback(i, t, latents)
|
882 |
+
else:
|
883 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
884 |
+
for i, t in enumerate(timesteps):
|
885 |
+
# expand the latents if we are doing classifier free guidance
|
886 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
887 |
+
|
888 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
889 |
+
|
890 |
+
# predict the noise residual
|
891 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
892 |
+
noise_pred = self.unet(
|
893 |
+
latent_model_input,
|
894 |
+
t,
|
895 |
+
encoder_hidden_states=prompt_embeds,
|
896 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
897 |
+
added_cond_kwargs=added_cond_kwargs,
|
898 |
+
return_dict=False,
|
899 |
+
)[0]
|
900 |
+
|
901 |
+
# perform guidance
|
902 |
+
if do_classifier_free_guidance:
|
903 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
904 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
905 |
+
|
906 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
907 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
908 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
909 |
+
|
910 |
+
# compute the previous noisy sample x_t -> x_t-1
|
911 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
912 |
+
|
913 |
+
# call the callback, if provided
|
914 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
915 |
+
progress_bar.update()
|
916 |
+
if callback is not None and i % callback_steps == 0:
|
917 |
+
callback(i, t, latents)
|
918 |
+
|
919 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
920 |
+
self.vae.to(dtype=torch.float32)
|
921 |
+
|
922 |
+
use_torch_2_0_or_xformers = isinstance(
|
923 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
924 |
+
(
|
925 |
+
AttnProcessor2_0,
|
926 |
+
XFormersAttnProcessor,
|
927 |
+
LoRAXFormersAttnProcessor,
|
928 |
+
LoRAAttnProcessor2_0,
|
929 |
+
),
|
930 |
+
)
|
931 |
+
# if xformers or torch_2_0 is used attention block does not need
|
932 |
+
# to be in float32 which can save lots of memory
|
933 |
+
if use_torch_2_0_or_xformers:
|
934 |
+
self.vae.post_quant_conv.to(latents.dtype)
|
935 |
+
self.vae.decoder.conv_in.to(latents.dtype)
|
936 |
+
self.vae.decoder.mid_block.to(latents.dtype)
|
937 |
+
else:
|
938 |
+
latents = latents.float()
|
939 |
+
|
940 |
+
if not output_type == "latent":
|
941 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
942 |
+
else:
|
943 |
+
image = latents
|
944 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
945 |
+
|
946 |
+
image = self.watermark.apply_watermark(image)
|
947 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
948 |
+
|
949 |
+
# Offload last model to CPU
|
950 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
951 |
+
self.final_offload_hook.offload()
|
952 |
+
|
953 |
+
if not return_dict:
|
954 |
+
return (image,)
|
955 |
+
|
956 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
957 |
+
|
958 |
+
def predict_x0(self, x_t, eps_t, t):
|
959 |
+
alpha_t = self.scheduler.alphas_cumprod[t.cpu().long().item()]
|
960 |
+
return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
|
961 |
+
|
962 |
+
def register_tokenmap_hooks(self):
|
963 |
+
r"""Function for registering hooks during evaluation.
|
964 |
+
We mainly store activation maps averaged over queries.
|
965 |
+
"""
|
966 |
+
self.forward_hooks = []
|
967 |
+
|
968 |
+
def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
|
969 |
+
r"""
|
970 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
971 |
+
"""
|
972 |
+
# out[0] - final output of attention layer
|
973 |
+
# out[1] - attention probability matrices
|
974 |
+
if name in n_maps:
|
975 |
+
n_maps[name] += 1
|
976 |
+
else:
|
977 |
+
n_maps[name] = 1
|
978 |
+
if 'attn2' in name:
|
979 |
+
assert out[1][0].shape[-1] == 77
|
980 |
+
if name in CrossAttentionLayers_XL and n_maps[name] > 10:
|
981 |
+
# if n_maps[name] > 10:
|
982 |
+
if name in crossattn_maps:
|
983 |
+
crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
984 |
+
else:
|
985 |
+
crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
986 |
+
# For visualization
|
987 |
+
# crossattn_maps[name].append(out[1][0].detach().cpu()[1:2])
|
988 |
+
else:
|
989 |
+
assert out[1][0].shape[-1] != 77
|
990 |
+
# if name in SelfAttentionLayers and n_maps[name] > 10:
|
991 |
+
if n_maps[name] > 10:
|
992 |
+
if name in selfattn_maps:
|
993 |
+
selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
994 |
+
else:
|
995 |
+
selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
996 |
+
|
997 |
+
selfattn_maps = collections.defaultdict(list)
|
998 |
+
crossattn_maps = collections.defaultdict(list)
|
999 |
+
n_maps = collections.defaultdict(list)
|
1000 |
+
|
1001 |
+
for name, module in self.unet.named_modules():
|
1002 |
+
leaf_name = name.split('.')[-1]
|
1003 |
+
if 'attn' in leaf_name:
|
1004 |
+
# Register hook to obtain outputs at every attention layer.
|
1005 |
+
self.forward_hooks.append(module.register_forward_hook(
|
1006 |
+
partial(save_activations, selfattn_maps,
|
1007 |
+
crossattn_maps, n_maps, name)
|
1008 |
+
))
|
1009 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
1010 |
+
self.selfattn_maps = selfattn_maps
|
1011 |
+
self.crossattn_maps = crossattn_maps
|
1012 |
+
self.n_maps = n_maps
|
1013 |
+
|
1014 |
+
def remove_tokenmap_hooks(self):
|
1015 |
+
for hook in self.forward_hooks:
|
1016 |
+
hook.remove()
|
1017 |
+
self.selfattn_maps = None
|
1018 |
+
self.crossattn_maps = None
|
1019 |
+
self.n_maps = None
|
1020 |
+
|
1021 |
+
def register_replacement_hooks(self, feat_inject_step=False):
|
1022 |
+
r"""Function for registering hooks to replace self attention.
|
1023 |
+
"""
|
1024 |
+
self.forward_replacement_hooks = []
|
1025 |
+
|
1026 |
+
def replace_activations(name, module, args):
|
1027 |
+
r"""
|
1028 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
1029 |
+
"""
|
1030 |
+
if 'attn1' in name:
|
1031 |
+
modified_args = (args[0], self.self_attention_maps_cur[name].to(args[0].device))
|
1032 |
+
return modified_args
|
1033 |
+
# cross attention injection
|
1034 |
+
# elif 'attn2' in name:
|
1035 |
+
# modified_map = {
|
1036 |
+
# 'reference': self.self_attention_maps_cur[name],
|
1037 |
+
# 'inject_pos': self.inject_pos,
|
1038 |
+
# }
|
1039 |
+
# modified_args = (args[0], modified_map)
|
1040 |
+
# return modified_args
|
1041 |
+
|
1042 |
+
def replace_resnet_activations(name, module, args):
|
1043 |
+
r"""
|
1044 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
1045 |
+
"""
|
1046 |
+
modified_args = (args[0], args[1],
|
1047 |
+
self.self_attention_maps_cur[name].to(args[0].device))
|
1048 |
+
return modified_args
|
1049 |
+
for name, module in self.unet.named_modules():
|
1050 |
+
leaf_name = name.split('.')[-1]
|
1051 |
+
if 'attn' in leaf_name and feat_inject_step:
|
1052 |
+
# Register hook to obtain outputs at every attention layer.
|
1053 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
1054 |
+
partial(replace_activations, name)
|
1055 |
+
))
|
1056 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
1057 |
+
# Register hook to obtain outputs at every attention layer.
|
1058 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
1059 |
+
partial(replace_resnet_activations, name)
|
1060 |
+
))
|
1061 |
+
|
1062 |
+
def remove_replacement_hooks(self):
|
1063 |
+
for hook in self.forward_replacement_hooks:
|
1064 |
+
hook.remove()
|
1065 |
+
|
1066 |
+
|
1067 |
+
def register_selfattn_hooks(self, feat_inject_step=False):
|
1068 |
+
r"""Function for registering hooks during evaluation.
|
1069 |
+
We mainly store activation maps averaged over queries.
|
1070 |
+
"""
|
1071 |
+
self.selfattn_forward_hooks = []
|
1072 |
+
|
1073 |
+
def save_activations(activations, name, module, inp, out):
|
1074 |
+
r"""
|
1075 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
1076 |
+
"""
|
1077 |
+
# out[0] - final output of attention layer
|
1078 |
+
# out[1] - attention probability matrix
|
1079 |
+
if 'attn2' in name:
|
1080 |
+
assert out[1][1].shape[-1] == 77
|
1081 |
+
# cross attention injection
|
1082 |
+
# activations[name] = out[1][1].detach()
|
1083 |
+
else:
|
1084 |
+
assert out[1][1].shape[-1] != 77
|
1085 |
+
activations[name] = out[1][1].detach().cpu()
|
1086 |
+
|
1087 |
+
def save_resnet_activations(activations, name, module, inp, out):
|
1088 |
+
r"""
|
1089 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
1090 |
+
"""
|
1091 |
+
# out[0] - final output of residual layer
|
1092 |
+
# out[1] - residual hidden feature
|
1093 |
+
# import ipdb;ipdb.set_trace()
|
1094 |
+
assert out[1].shape[-1] == 64
|
1095 |
+
activations[name] = out[1].detach().cpu()
|
1096 |
+
attention_dict = collections.defaultdict(list)
|
1097 |
+
for name, module in self.unet.named_modules():
|
1098 |
+
leaf_name = name.split('.')[-1]
|
1099 |
+
if 'attn' in leaf_name and feat_inject_step:
|
1100 |
+
# Register hook to obtain outputs at every attention layer.
|
1101 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
1102 |
+
partial(save_activations, attention_dict, name)
|
1103 |
+
))
|
1104 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
1105 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
1106 |
+
partial(save_resnet_activations, attention_dict, name)
|
1107 |
+
))
|
1108 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
1109 |
+
self.self_attention_maps_cur = attention_dict
|
1110 |
+
|
1111 |
+
def remove_selfattn_hooks(self):
|
1112 |
+
for hook in self.selfattn_forward_hooks:
|
1113 |
+
hook.remove()
|
1114 |
+
|
1115 |
+
def register_fontsize_hooks(self, text_format_dict={}):
|
1116 |
+
r"""Function for registering hooks to replace self attention.
|
1117 |
+
"""
|
1118 |
+
self.forward_fontsize_hooks = []
|
1119 |
+
|
1120 |
+
def adjust_attn_weights(name, module, args):
|
1121 |
+
r"""
|
1122 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
1123 |
+
"""
|
1124 |
+
if 'attn2' in name:
|
1125 |
+
modified_args = (args[0], None, attn_weights)
|
1126 |
+
return modified_args
|
1127 |
+
|
1128 |
+
if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
|
1129 |
+
attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
|
1130 |
+
else:
|
1131 |
+
attn_weights = None
|
1132 |
+
|
1133 |
+
for name, module in self.unet.named_modules():
|
1134 |
+
leaf_name = name.split('.')[-1]
|
1135 |
+
if 'attn' in leaf_name and attn_weights is not None:
|
1136 |
+
# Register hook to obtain outputs at every attention layer.
|
1137 |
+
self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
|
1138 |
+
partial(adjust_attn_weights, name)
|
1139 |
+
))
|
1140 |
+
|
1141 |
+
def remove_fontsize_hooks(self):
|
1142 |
+
for hook in self.forward_fontsize_hooks:
|
1143 |
+
hook.remove()
|
models/resnet.py
ADDED
@@ -0,0 +1,882 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention import AdaGroupNorm
|
25 |
+
from models.attention_processor import SpatialNorm
|
26 |
+
|
27 |
+
|
28 |
+
class Upsample1D(nn.Module):
|
29 |
+
"""A 1D upsampling layer with an optional convolution.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
channels (`int`):
|
33 |
+
number of channels in the inputs and outputs.
|
34 |
+
use_conv (`bool`, default `False`):
|
35 |
+
option to use a convolution.
|
36 |
+
use_conv_transpose (`bool`, default `False`):
|
37 |
+
option to use a convolution transpose.
|
38 |
+
out_channels (`int`, optional):
|
39 |
+
number of output channels. Defaults to `channels`.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
43 |
+
super().__init__()
|
44 |
+
self.channels = channels
|
45 |
+
self.out_channels = out_channels or channels
|
46 |
+
self.use_conv = use_conv
|
47 |
+
self.use_conv_transpose = use_conv_transpose
|
48 |
+
self.name = name
|
49 |
+
|
50 |
+
self.conv = None
|
51 |
+
if use_conv_transpose:
|
52 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
53 |
+
elif use_conv:
|
54 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
55 |
+
|
56 |
+
def forward(self, inputs):
|
57 |
+
assert inputs.shape[1] == self.channels
|
58 |
+
if self.use_conv_transpose:
|
59 |
+
return self.conv(inputs)
|
60 |
+
|
61 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
62 |
+
|
63 |
+
if self.use_conv:
|
64 |
+
outputs = self.conv(outputs)
|
65 |
+
|
66 |
+
return outputs
|
67 |
+
|
68 |
+
|
69 |
+
class Downsample1D(nn.Module):
|
70 |
+
"""A 1D downsampling layer with an optional convolution.
|
71 |
+
|
72 |
+
Parameters:
|
73 |
+
channels (`int`):
|
74 |
+
number of channels in the inputs and outputs.
|
75 |
+
use_conv (`bool`, default `False`):
|
76 |
+
option to use a convolution.
|
77 |
+
out_channels (`int`, optional):
|
78 |
+
number of output channels. Defaults to `channels`.
|
79 |
+
padding (`int`, default `1`):
|
80 |
+
padding for the convolution.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
84 |
+
super().__init__()
|
85 |
+
self.channels = channels
|
86 |
+
self.out_channels = out_channels or channels
|
87 |
+
self.use_conv = use_conv
|
88 |
+
self.padding = padding
|
89 |
+
stride = 2
|
90 |
+
self.name = name
|
91 |
+
|
92 |
+
if use_conv:
|
93 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
94 |
+
else:
|
95 |
+
assert self.channels == self.out_channels
|
96 |
+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
97 |
+
|
98 |
+
def forward(self, inputs):
|
99 |
+
assert inputs.shape[1] == self.channels
|
100 |
+
return self.conv(inputs)
|
101 |
+
|
102 |
+
|
103 |
+
class Upsample2D(nn.Module):
|
104 |
+
"""A 2D upsampling layer with an optional convolution.
|
105 |
+
|
106 |
+
Parameters:
|
107 |
+
channels (`int`):
|
108 |
+
number of channels in the inputs and outputs.
|
109 |
+
use_conv (`bool`, default `False`):
|
110 |
+
option to use a convolution.
|
111 |
+
use_conv_transpose (`bool`, default `False`):
|
112 |
+
option to use a convolution transpose.
|
113 |
+
out_channels (`int`, optional):
|
114 |
+
number of output channels. Defaults to `channels`.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
118 |
+
super().__init__()
|
119 |
+
self.channels = channels
|
120 |
+
self.out_channels = out_channels or channels
|
121 |
+
self.use_conv = use_conv
|
122 |
+
self.use_conv_transpose = use_conv_transpose
|
123 |
+
self.name = name
|
124 |
+
|
125 |
+
conv = None
|
126 |
+
if use_conv_transpose:
|
127 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
128 |
+
elif use_conv:
|
129 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
130 |
+
|
131 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
132 |
+
if name == "conv":
|
133 |
+
self.conv = conv
|
134 |
+
else:
|
135 |
+
self.Conv2d_0 = conv
|
136 |
+
|
137 |
+
def forward(self, hidden_states, output_size=None):
|
138 |
+
assert hidden_states.shape[1] == self.channels
|
139 |
+
|
140 |
+
if self.use_conv_transpose:
|
141 |
+
return self.conv(hidden_states)
|
142 |
+
|
143 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
144 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
145 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
146 |
+
dtype = hidden_states.dtype
|
147 |
+
if dtype == torch.bfloat16:
|
148 |
+
hidden_states = hidden_states.to(torch.float32)
|
149 |
+
|
150 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
151 |
+
if hidden_states.shape[0] >= 64:
|
152 |
+
hidden_states = hidden_states.contiguous()
|
153 |
+
|
154 |
+
# if `output_size` is passed we force the interpolation output
|
155 |
+
# size and do not make use of `scale_factor=2`
|
156 |
+
if output_size is None:
|
157 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
158 |
+
else:
|
159 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
160 |
+
|
161 |
+
# If the input is bfloat16, we cast back to bfloat16
|
162 |
+
if dtype == torch.bfloat16:
|
163 |
+
hidden_states = hidden_states.to(dtype)
|
164 |
+
|
165 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
166 |
+
if self.use_conv:
|
167 |
+
if self.name == "conv":
|
168 |
+
hidden_states = self.conv(hidden_states)
|
169 |
+
else:
|
170 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
171 |
+
|
172 |
+
return hidden_states
|
173 |
+
|
174 |
+
|
175 |
+
class Downsample2D(nn.Module):
|
176 |
+
"""A 2D downsampling layer with an optional convolution.
|
177 |
+
|
178 |
+
Parameters:
|
179 |
+
channels (`int`):
|
180 |
+
number of channels in the inputs and outputs.
|
181 |
+
use_conv (`bool`, default `False`):
|
182 |
+
option to use a convolution.
|
183 |
+
out_channels (`int`, optional):
|
184 |
+
number of output channels. Defaults to `channels`.
|
185 |
+
padding (`int`, default `1`):
|
186 |
+
padding for the convolution.
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
190 |
+
super().__init__()
|
191 |
+
self.channels = channels
|
192 |
+
self.out_channels = out_channels or channels
|
193 |
+
self.use_conv = use_conv
|
194 |
+
self.padding = padding
|
195 |
+
stride = 2
|
196 |
+
self.name = name
|
197 |
+
|
198 |
+
if use_conv:
|
199 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
200 |
+
else:
|
201 |
+
assert self.channels == self.out_channels
|
202 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
203 |
+
|
204 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
205 |
+
if name == "conv":
|
206 |
+
self.Conv2d_0 = conv
|
207 |
+
self.conv = conv
|
208 |
+
elif name == "Conv2d_0":
|
209 |
+
self.conv = conv
|
210 |
+
else:
|
211 |
+
self.conv = conv
|
212 |
+
|
213 |
+
def forward(self, hidden_states):
|
214 |
+
assert hidden_states.shape[1] == self.channels
|
215 |
+
if self.use_conv and self.padding == 0:
|
216 |
+
pad = (0, 1, 0, 1)
|
217 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
218 |
+
|
219 |
+
assert hidden_states.shape[1] == self.channels
|
220 |
+
hidden_states = self.conv(hidden_states)
|
221 |
+
|
222 |
+
return hidden_states
|
223 |
+
|
224 |
+
|
225 |
+
class FirUpsample2D(nn.Module):
|
226 |
+
"""A 2D FIR upsampling layer with an optional convolution.
|
227 |
+
|
228 |
+
Parameters:
|
229 |
+
channels (`int`):
|
230 |
+
number of channels in the inputs and outputs.
|
231 |
+
use_conv (`bool`, default `False`):
|
232 |
+
option to use a convolution.
|
233 |
+
out_channels (`int`, optional):
|
234 |
+
number of output channels. Defaults to `channels`.
|
235 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
236 |
+
kernel for the FIR filter.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
240 |
+
super().__init__()
|
241 |
+
out_channels = out_channels if out_channels else channels
|
242 |
+
if use_conv:
|
243 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
244 |
+
self.use_conv = use_conv
|
245 |
+
self.fir_kernel = fir_kernel
|
246 |
+
self.out_channels = out_channels
|
247 |
+
|
248 |
+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
249 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
250 |
+
|
251 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
252 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
253 |
+
arbitrary order.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
257 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
258 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
259 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
260 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
261 |
+
factor: Integer upsampling factor (default: 2).
|
262 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
266 |
+
datatype as `hidden_states`.
|
267 |
+
"""
|
268 |
+
|
269 |
+
assert isinstance(factor, int) and factor >= 1
|
270 |
+
|
271 |
+
# Setup filter kernel.
|
272 |
+
if kernel is None:
|
273 |
+
kernel = [1] * factor
|
274 |
+
|
275 |
+
# setup kernel
|
276 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
277 |
+
if kernel.ndim == 1:
|
278 |
+
kernel = torch.outer(kernel, kernel)
|
279 |
+
kernel /= torch.sum(kernel)
|
280 |
+
|
281 |
+
kernel = kernel * (gain * (factor**2))
|
282 |
+
|
283 |
+
if self.use_conv:
|
284 |
+
convH = weight.shape[2]
|
285 |
+
convW = weight.shape[3]
|
286 |
+
inC = weight.shape[1]
|
287 |
+
|
288 |
+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
289 |
+
|
290 |
+
stride = (factor, factor)
|
291 |
+
# Determine data dimensions.
|
292 |
+
output_shape = (
|
293 |
+
(hidden_states.shape[2] - 1) * factor + convH,
|
294 |
+
(hidden_states.shape[3] - 1) * factor + convW,
|
295 |
+
)
|
296 |
+
output_padding = (
|
297 |
+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
298 |
+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
299 |
+
)
|
300 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
301 |
+
num_groups = hidden_states.shape[1] // inC
|
302 |
+
|
303 |
+
# Transpose weights.
|
304 |
+
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
305 |
+
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
306 |
+
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
307 |
+
|
308 |
+
inverse_conv = F.conv_transpose2d(
|
309 |
+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
|
310 |
+
)
|
311 |
+
|
312 |
+
output = upfirdn2d_native(
|
313 |
+
inverse_conv,
|
314 |
+
torch.tensor(kernel, device=inverse_conv.device),
|
315 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
pad_value = kernel.shape[0] - factor
|
319 |
+
output = upfirdn2d_native(
|
320 |
+
hidden_states,
|
321 |
+
torch.tensor(kernel, device=hidden_states.device),
|
322 |
+
up=factor,
|
323 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
324 |
+
)
|
325 |
+
|
326 |
+
return output
|
327 |
+
|
328 |
+
def forward(self, hidden_states):
|
329 |
+
if self.use_conv:
|
330 |
+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
331 |
+
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
332 |
+
else:
|
333 |
+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
334 |
+
|
335 |
+
return height
|
336 |
+
|
337 |
+
|
338 |
+
class FirDownsample2D(nn.Module):
|
339 |
+
"""A 2D FIR downsampling layer with an optional convolution.
|
340 |
+
|
341 |
+
Parameters:
|
342 |
+
channels (`int`):
|
343 |
+
number of channels in the inputs and outputs.
|
344 |
+
use_conv (`bool`, default `False`):
|
345 |
+
option to use a convolution.
|
346 |
+
out_channels (`int`, optional):
|
347 |
+
number of output channels. Defaults to `channels`.
|
348 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
349 |
+
kernel for the FIR filter.
|
350 |
+
"""
|
351 |
+
|
352 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
353 |
+
super().__init__()
|
354 |
+
out_channels = out_channels if out_channels else channels
|
355 |
+
if use_conv:
|
356 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
357 |
+
self.fir_kernel = fir_kernel
|
358 |
+
self.use_conv = use_conv
|
359 |
+
self.out_channels = out_channels
|
360 |
+
|
361 |
+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
362 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
363 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
364 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
365 |
+
arbitrary order.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
369 |
+
weight:
|
370 |
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
371 |
+
performed by `inChannels = x.shape[0] // numGroups`.
|
372 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
373 |
+
factor`, which corresponds to average pooling.
|
374 |
+
factor: Integer downsampling factor (default: 2).
|
375 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
379 |
+
same datatype as `x`.
|
380 |
+
"""
|
381 |
+
|
382 |
+
assert isinstance(factor, int) and factor >= 1
|
383 |
+
if kernel is None:
|
384 |
+
kernel = [1] * factor
|
385 |
+
|
386 |
+
# setup kernel
|
387 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
388 |
+
if kernel.ndim == 1:
|
389 |
+
kernel = torch.outer(kernel, kernel)
|
390 |
+
kernel /= torch.sum(kernel)
|
391 |
+
|
392 |
+
kernel = kernel * gain
|
393 |
+
|
394 |
+
if self.use_conv:
|
395 |
+
_, _, convH, convW = weight.shape
|
396 |
+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
397 |
+
stride_value = [factor, factor]
|
398 |
+
upfirdn_input = upfirdn2d_native(
|
399 |
+
hidden_states,
|
400 |
+
torch.tensor(kernel, device=hidden_states.device),
|
401 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
402 |
+
)
|
403 |
+
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
404 |
+
else:
|
405 |
+
pad_value = kernel.shape[0] - factor
|
406 |
+
output = upfirdn2d_native(
|
407 |
+
hidden_states,
|
408 |
+
torch.tensor(kernel, device=hidden_states.device),
|
409 |
+
down=factor,
|
410 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
411 |
+
)
|
412 |
+
|
413 |
+
return output
|
414 |
+
|
415 |
+
def forward(self, hidden_states):
|
416 |
+
if self.use_conv:
|
417 |
+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
418 |
+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
419 |
+
else:
|
420 |
+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
421 |
+
|
422 |
+
return hidden_states
|
423 |
+
|
424 |
+
|
425 |
+
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
426 |
+
class KDownsample2D(nn.Module):
|
427 |
+
def __init__(self, pad_mode="reflect"):
|
428 |
+
super().__init__()
|
429 |
+
self.pad_mode = pad_mode
|
430 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
431 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
432 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
433 |
+
|
434 |
+
def forward(self, inputs):
|
435 |
+
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
436 |
+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
437 |
+
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
438 |
+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
439 |
+
weight[indices, indices] = kernel
|
440 |
+
return F.conv2d(inputs, weight, stride=2)
|
441 |
+
|
442 |
+
|
443 |
+
class KUpsample2D(nn.Module):
|
444 |
+
def __init__(self, pad_mode="reflect"):
|
445 |
+
super().__init__()
|
446 |
+
self.pad_mode = pad_mode
|
447 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
448 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
449 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
450 |
+
|
451 |
+
def forward(self, inputs):
|
452 |
+
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
453 |
+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
454 |
+
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
455 |
+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
456 |
+
weight[indices, indices] = kernel
|
457 |
+
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
458 |
+
|
459 |
+
|
460 |
+
class ResnetBlock2D(nn.Module):
|
461 |
+
r"""
|
462 |
+
A Resnet block.
|
463 |
+
|
464 |
+
Parameters:
|
465 |
+
in_channels (`int`): The number of channels in the input.
|
466 |
+
out_channels (`int`, *optional*, default to be `None`):
|
467 |
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
468 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
469 |
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
470 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
471 |
+
groups_out (`int`, *optional*, default to None):
|
472 |
+
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
473 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
474 |
+
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
475 |
+
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
476 |
+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
|
477 |
+
"ada_group" for a stronger conditioning with scale and shift.
|
478 |
+
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
|
479 |
+
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
480 |
+
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
481 |
+
use_in_shortcut (`bool`, *optional*, default to `True`):
|
482 |
+
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
483 |
+
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
484 |
+
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
485 |
+
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
486 |
+
`conv_shortcut` output.
|
487 |
+
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
488 |
+
If None, same as `out_channels`.
|
489 |
+
"""
|
490 |
+
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
*,
|
494 |
+
in_channels,
|
495 |
+
out_channels=None,
|
496 |
+
conv_shortcut=False,
|
497 |
+
dropout=0.0,
|
498 |
+
temb_channels=512,
|
499 |
+
groups=32,
|
500 |
+
groups_out=None,
|
501 |
+
pre_norm=True,
|
502 |
+
eps=1e-6,
|
503 |
+
non_linearity="swish",
|
504 |
+
skip_time_act=False,
|
505 |
+
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
|
506 |
+
kernel=None,
|
507 |
+
output_scale_factor=1.0,
|
508 |
+
use_in_shortcut=None,
|
509 |
+
up=False,
|
510 |
+
down=False,
|
511 |
+
conv_shortcut_bias: bool = True,
|
512 |
+
conv_2d_out_channels: Optional[int] = None,
|
513 |
+
):
|
514 |
+
super().__init__()
|
515 |
+
self.pre_norm = pre_norm
|
516 |
+
self.pre_norm = True
|
517 |
+
self.in_channels = in_channels
|
518 |
+
out_channels = in_channels if out_channels is None else out_channels
|
519 |
+
self.out_channels = out_channels
|
520 |
+
self.use_conv_shortcut = conv_shortcut
|
521 |
+
self.up = up
|
522 |
+
self.down = down
|
523 |
+
self.output_scale_factor = output_scale_factor
|
524 |
+
self.time_embedding_norm = time_embedding_norm
|
525 |
+
self.skip_time_act = skip_time_act
|
526 |
+
|
527 |
+
if groups_out is None:
|
528 |
+
groups_out = groups
|
529 |
+
|
530 |
+
if self.time_embedding_norm == "ada_group":
|
531 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
532 |
+
elif self.time_embedding_norm == "spatial":
|
533 |
+
self.norm1 = SpatialNorm(in_channels, temb_channels)
|
534 |
+
else:
|
535 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
536 |
+
|
537 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
538 |
+
|
539 |
+
if temb_channels is not None:
|
540 |
+
if self.time_embedding_norm == "default":
|
541 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
542 |
+
elif self.time_embedding_norm == "scale_shift":
|
543 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
544 |
+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
545 |
+
self.time_emb_proj = None
|
546 |
+
else:
|
547 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
548 |
+
else:
|
549 |
+
self.time_emb_proj = None
|
550 |
+
|
551 |
+
if self.time_embedding_norm == "ada_group":
|
552 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
553 |
+
elif self.time_embedding_norm == "spatial":
|
554 |
+
self.norm2 = SpatialNorm(out_channels, temb_channels)
|
555 |
+
else:
|
556 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
557 |
+
|
558 |
+
self.dropout = torch.nn.Dropout(dropout)
|
559 |
+
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
560 |
+
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
561 |
+
|
562 |
+
self.nonlinearity = get_activation(non_linearity)
|
563 |
+
|
564 |
+
self.upsample = self.downsample = None
|
565 |
+
if self.up:
|
566 |
+
if kernel == "fir":
|
567 |
+
fir_kernel = (1, 3, 3, 1)
|
568 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
569 |
+
elif kernel == "sde_vp":
|
570 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
571 |
+
else:
|
572 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
573 |
+
elif self.down:
|
574 |
+
if kernel == "fir":
|
575 |
+
fir_kernel = (1, 3, 3, 1)
|
576 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
577 |
+
elif kernel == "sde_vp":
|
578 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
579 |
+
else:
|
580 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
581 |
+
|
582 |
+
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
|
583 |
+
|
584 |
+
self.conv_shortcut = None
|
585 |
+
if self.use_in_shortcut:
|
586 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
587 |
+
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
|
588 |
+
)
|
589 |
+
|
590 |
+
# Rich-Text: feature injection
|
591 |
+
def forward(self, input_tensor, temb, inject_states=None):
|
592 |
+
hidden_states = input_tensor
|
593 |
+
|
594 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
595 |
+
hidden_states = self.norm1(hidden_states, temb)
|
596 |
+
else:
|
597 |
+
hidden_states = self.norm1(hidden_states)
|
598 |
+
|
599 |
+
hidden_states = self.nonlinearity(hidden_states)
|
600 |
+
|
601 |
+
if self.upsample is not None:
|
602 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
603 |
+
if hidden_states.shape[0] >= 64:
|
604 |
+
input_tensor = input_tensor.contiguous()
|
605 |
+
hidden_states = hidden_states.contiguous()
|
606 |
+
input_tensor = self.upsample(input_tensor)
|
607 |
+
hidden_states = self.upsample(hidden_states)
|
608 |
+
elif self.downsample is not None:
|
609 |
+
input_tensor = self.downsample(input_tensor)
|
610 |
+
hidden_states = self.downsample(hidden_states)
|
611 |
+
|
612 |
+
hidden_states = self.conv1(hidden_states)
|
613 |
+
|
614 |
+
if self.time_emb_proj is not None:
|
615 |
+
if not self.skip_time_act:
|
616 |
+
temb = self.nonlinearity(temb)
|
617 |
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
618 |
+
|
619 |
+
if temb is not None and self.time_embedding_norm == "default":
|
620 |
+
hidden_states = hidden_states + temb
|
621 |
+
|
622 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
623 |
+
hidden_states = self.norm2(hidden_states, temb)
|
624 |
+
else:
|
625 |
+
hidden_states = self.norm2(hidden_states)
|
626 |
+
|
627 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
628 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
629 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
630 |
+
|
631 |
+
hidden_states = self.nonlinearity(hidden_states)
|
632 |
+
|
633 |
+
hidden_states = self.dropout(hidden_states)
|
634 |
+
hidden_states = self.conv2(hidden_states)
|
635 |
+
|
636 |
+
if self.conv_shortcut is not None:
|
637 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
638 |
+
|
639 |
+
# Rich-Text: feature injection
|
640 |
+
if inject_states is not None:
|
641 |
+
output_tensor = (input_tensor + inject_states) / self.output_scale_factor
|
642 |
+
else:
|
643 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
644 |
+
|
645 |
+
return output_tensor, hidden_states
|
646 |
+
|
647 |
+
|
648 |
+
# unet_rl.py
|
649 |
+
def rearrange_dims(tensor):
|
650 |
+
if len(tensor.shape) == 2:
|
651 |
+
return tensor[:, :, None]
|
652 |
+
if len(tensor.shape) == 3:
|
653 |
+
return tensor[:, :, None, :]
|
654 |
+
elif len(tensor.shape) == 4:
|
655 |
+
return tensor[:, :, 0, :]
|
656 |
+
else:
|
657 |
+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
658 |
+
|
659 |
+
|
660 |
+
class Conv1dBlock(nn.Module):
|
661 |
+
"""
|
662 |
+
Conv1d --> GroupNorm --> Mish
|
663 |
+
"""
|
664 |
+
|
665 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
666 |
+
super().__init__()
|
667 |
+
|
668 |
+
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
669 |
+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
670 |
+
self.mish = nn.Mish()
|
671 |
+
|
672 |
+
def forward(self, inputs):
|
673 |
+
intermediate_repr = self.conv1d(inputs)
|
674 |
+
intermediate_repr = rearrange_dims(intermediate_repr)
|
675 |
+
intermediate_repr = self.group_norm(intermediate_repr)
|
676 |
+
intermediate_repr = rearrange_dims(intermediate_repr)
|
677 |
+
output = self.mish(intermediate_repr)
|
678 |
+
return output
|
679 |
+
|
680 |
+
|
681 |
+
# unet_rl.py
|
682 |
+
class ResidualTemporalBlock1D(nn.Module):
|
683 |
+
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
684 |
+
super().__init__()
|
685 |
+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
686 |
+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
687 |
+
|
688 |
+
self.time_emb_act = nn.Mish()
|
689 |
+
self.time_emb = nn.Linear(embed_dim, out_channels)
|
690 |
+
|
691 |
+
self.residual_conv = (
|
692 |
+
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
693 |
+
)
|
694 |
+
|
695 |
+
def forward(self, inputs, t):
|
696 |
+
"""
|
697 |
+
Args:
|
698 |
+
inputs : [ batch_size x inp_channels x horizon ]
|
699 |
+
t : [ batch_size x embed_dim ]
|
700 |
+
|
701 |
+
returns:
|
702 |
+
out : [ batch_size x out_channels x horizon ]
|
703 |
+
"""
|
704 |
+
t = self.time_emb_act(t)
|
705 |
+
t = self.time_emb(t)
|
706 |
+
out = self.conv_in(inputs) + rearrange_dims(t)
|
707 |
+
out = self.conv_out(out)
|
708 |
+
return out + self.residual_conv(inputs)
|
709 |
+
|
710 |
+
|
711 |
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
712 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
713 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
714 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
715 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
716 |
+
a: multiple of the upsampling factor.
|
717 |
+
|
718 |
+
Args:
|
719 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
720 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
721 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
722 |
+
factor: Integer upsampling factor (default: 2).
|
723 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
724 |
+
|
725 |
+
Returns:
|
726 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]`
|
727 |
+
"""
|
728 |
+
assert isinstance(factor, int) and factor >= 1
|
729 |
+
if kernel is None:
|
730 |
+
kernel = [1] * factor
|
731 |
+
|
732 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
733 |
+
if kernel.ndim == 1:
|
734 |
+
kernel = torch.outer(kernel, kernel)
|
735 |
+
kernel /= torch.sum(kernel)
|
736 |
+
|
737 |
+
kernel = kernel * (gain * (factor**2))
|
738 |
+
pad_value = kernel.shape[0] - factor
|
739 |
+
output = upfirdn2d_native(
|
740 |
+
hidden_states,
|
741 |
+
kernel.to(device=hidden_states.device),
|
742 |
+
up=factor,
|
743 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
744 |
+
)
|
745 |
+
return output
|
746 |
+
|
747 |
+
|
748 |
+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
749 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
750 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
751 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
752 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
753 |
+
shape is a multiple of the downsampling factor.
|
754 |
+
|
755 |
+
Args:
|
756 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
757 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
758 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
759 |
+
factor: Integer downsampling factor (default: 2).
|
760 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
761 |
+
|
762 |
+
Returns:
|
763 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]`
|
764 |
+
"""
|
765 |
+
|
766 |
+
assert isinstance(factor, int) and factor >= 1
|
767 |
+
if kernel is None:
|
768 |
+
kernel = [1] * factor
|
769 |
+
|
770 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
771 |
+
if kernel.ndim == 1:
|
772 |
+
kernel = torch.outer(kernel, kernel)
|
773 |
+
kernel /= torch.sum(kernel)
|
774 |
+
|
775 |
+
kernel = kernel * gain
|
776 |
+
pad_value = kernel.shape[0] - factor
|
777 |
+
output = upfirdn2d_native(
|
778 |
+
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
|
779 |
+
)
|
780 |
+
return output
|
781 |
+
|
782 |
+
|
783 |
+
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
784 |
+
up_x = up_y = up
|
785 |
+
down_x = down_y = down
|
786 |
+
pad_x0 = pad_y0 = pad[0]
|
787 |
+
pad_x1 = pad_y1 = pad[1]
|
788 |
+
|
789 |
+
_, channel, in_h, in_w = tensor.shape
|
790 |
+
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
791 |
+
|
792 |
+
_, in_h, in_w, minor = tensor.shape
|
793 |
+
kernel_h, kernel_w = kernel.shape
|
794 |
+
|
795 |
+
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
796 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
797 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
798 |
+
|
799 |
+
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
800 |
+
out = out.to(tensor.device) # Move back to mps if necessary
|
801 |
+
out = out[
|
802 |
+
:,
|
803 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
804 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
805 |
+
:,
|
806 |
+
]
|
807 |
+
|
808 |
+
out = out.permute(0, 3, 1, 2)
|
809 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
810 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
811 |
+
out = F.conv2d(out, w)
|
812 |
+
out = out.reshape(
|
813 |
+
-1,
|
814 |
+
minor,
|
815 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
816 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
817 |
+
)
|
818 |
+
out = out.permute(0, 2, 3, 1)
|
819 |
+
out = out[:, ::down_y, ::down_x, :]
|
820 |
+
|
821 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
822 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
823 |
+
|
824 |
+
return out.view(-1, channel, out_h, out_w)
|
825 |
+
|
826 |
+
|
827 |
+
class TemporalConvLayer(nn.Module):
|
828 |
+
"""
|
829 |
+
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
830 |
+
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
831 |
+
"""
|
832 |
+
|
833 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0):
|
834 |
+
super().__init__()
|
835 |
+
out_dim = out_dim or in_dim
|
836 |
+
self.in_dim = in_dim
|
837 |
+
self.out_dim = out_dim
|
838 |
+
|
839 |
+
# conv layers
|
840 |
+
self.conv1 = nn.Sequential(
|
841 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
|
842 |
+
)
|
843 |
+
self.conv2 = nn.Sequential(
|
844 |
+
nn.GroupNorm(32, out_dim),
|
845 |
+
nn.SiLU(),
|
846 |
+
nn.Dropout(dropout),
|
847 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
848 |
+
)
|
849 |
+
self.conv3 = nn.Sequential(
|
850 |
+
nn.GroupNorm(32, out_dim),
|
851 |
+
nn.SiLU(),
|
852 |
+
nn.Dropout(dropout),
|
853 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
854 |
+
)
|
855 |
+
self.conv4 = nn.Sequential(
|
856 |
+
nn.GroupNorm(32, out_dim),
|
857 |
+
nn.SiLU(),
|
858 |
+
nn.Dropout(dropout),
|
859 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
860 |
+
)
|
861 |
+
|
862 |
+
# zero out the last layer params,so the conv block is identity
|
863 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
864 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
865 |
+
|
866 |
+
def forward(self, hidden_states, num_frames=1):
|
867 |
+
hidden_states = (
|
868 |
+
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
|
869 |
+
)
|
870 |
+
|
871 |
+
identity = hidden_states
|
872 |
+
hidden_states = self.conv1(hidden_states)
|
873 |
+
hidden_states = self.conv2(hidden_states)
|
874 |
+
hidden_states = self.conv3(hidden_states)
|
875 |
+
hidden_states = self.conv4(hidden_states)
|
876 |
+
|
877 |
+
hidden_states = identity + hidden_states
|
878 |
+
|
879 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
|
880 |
+
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
|
881 |
+
)
|
882 |
+
return hidden_states
|
models/transformer_2d.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
23 |
+
from diffusers.utils import BaseOutput, deprecate
|
24 |
+
from diffusers.models.embeddings import PatchEmbed
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
|
27 |
+
from models.attention import BasicTransformerBlock
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Transformer2DModelOutput(BaseOutput):
|
31 |
+
"""
|
32 |
+
The output of [`Transformer2DModel`].
|
33 |
+
|
34 |
+
Args:
|
35 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
36 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
37 |
+
distributions for the unnoised latent pixels.
|
38 |
+
"""
|
39 |
+
|
40 |
+
sample: torch.FloatTensor
|
41 |
+
|
42 |
+
|
43 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
44 |
+
"""
|
45 |
+
A 2D Transformer model for image-like data.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
49 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
50 |
+
in_channels (`int`, *optional*):
|
51 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
52 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
53 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
54 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
55 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
56 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
57 |
+
num_vector_embeds (`int`, *optional*):
|
58 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
59 |
+
Includes the class for the masked latent pixel.
|
60 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
61 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
62 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
63 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
64 |
+
added to the hidden states.
|
65 |
+
|
66 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
67 |
+
attention_bias (`bool`, *optional*):
|
68 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
69 |
+
"""
|
70 |
+
|
71 |
+
@register_to_config
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
num_attention_heads: int = 16,
|
75 |
+
attention_head_dim: int = 88,
|
76 |
+
in_channels: Optional[int] = None,
|
77 |
+
out_channels: Optional[int] = None,
|
78 |
+
num_layers: int = 1,
|
79 |
+
dropout: float = 0.0,
|
80 |
+
norm_num_groups: int = 32,
|
81 |
+
cross_attention_dim: Optional[int] = None,
|
82 |
+
attention_bias: bool = False,
|
83 |
+
sample_size: Optional[int] = None,
|
84 |
+
num_vector_embeds: Optional[int] = None,
|
85 |
+
patch_size: Optional[int] = None,
|
86 |
+
activation_fn: str = "geglu",
|
87 |
+
num_embeds_ada_norm: Optional[int] = None,
|
88 |
+
use_linear_projection: bool = False,
|
89 |
+
only_cross_attention: bool = False,
|
90 |
+
upcast_attention: bool = False,
|
91 |
+
norm_type: str = "layer_norm",
|
92 |
+
norm_elementwise_affine: bool = True,
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
self.use_linear_projection = use_linear_projection
|
96 |
+
self.num_attention_heads = num_attention_heads
|
97 |
+
self.attention_head_dim = attention_head_dim
|
98 |
+
inner_dim = num_attention_heads * attention_head_dim
|
99 |
+
|
100 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
101 |
+
# Define whether input is continuous or discrete depending on configuration
|
102 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
103 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
104 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
105 |
+
|
106 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
107 |
+
deprecation_message = (
|
108 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
109 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
110 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
111 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
112 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
113 |
+
)
|
114 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
115 |
+
norm_type = "ada_norm"
|
116 |
+
|
117 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
118 |
+
raise ValueError(
|
119 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
120 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
121 |
+
)
|
122 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
123 |
+
raise ValueError(
|
124 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
125 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
126 |
+
)
|
127 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
128 |
+
raise ValueError(
|
129 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
130 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
131 |
+
)
|
132 |
+
|
133 |
+
# 2. Define input layers
|
134 |
+
if self.is_input_continuous:
|
135 |
+
self.in_channels = in_channels
|
136 |
+
|
137 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
138 |
+
if use_linear_projection:
|
139 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
140 |
+
else:
|
141 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
142 |
+
elif self.is_input_vectorized:
|
143 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
144 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
145 |
+
|
146 |
+
self.height = sample_size
|
147 |
+
self.width = sample_size
|
148 |
+
self.num_vector_embeds = num_vector_embeds
|
149 |
+
self.num_latent_pixels = self.height * self.width
|
150 |
+
|
151 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
152 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
153 |
+
)
|
154 |
+
elif self.is_input_patches:
|
155 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
156 |
+
|
157 |
+
self.height = sample_size
|
158 |
+
self.width = sample_size
|
159 |
+
|
160 |
+
self.patch_size = patch_size
|
161 |
+
self.pos_embed = PatchEmbed(
|
162 |
+
height=sample_size,
|
163 |
+
width=sample_size,
|
164 |
+
patch_size=patch_size,
|
165 |
+
in_channels=in_channels,
|
166 |
+
embed_dim=inner_dim,
|
167 |
+
)
|
168 |
+
|
169 |
+
# 3. Define transformers blocks
|
170 |
+
self.transformer_blocks = nn.ModuleList(
|
171 |
+
[
|
172 |
+
BasicTransformerBlock(
|
173 |
+
inner_dim,
|
174 |
+
num_attention_heads,
|
175 |
+
attention_head_dim,
|
176 |
+
dropout=dropout,
|
177 |
+
cross_attention_dim=cross_attention_dim,
|
178 |
+
activation_fn=activation_fn,
|
179 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
180 |
+
attention_bias=attention_bias,
|
181 |
+
only_cross_attention=only_cross_attention,
|
182 |
+
upcast_attention=upcast_attention,
|
183 |
+
norm_type=norm_type,
|
184 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
185 |
+
)
|
186 |
+
for d in range(num_layers)
|
187 |
+
]
|
188 |
+
)
|
189 |
+
|
190 |
+
# 4. Define output layers
|
191 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
192 |
+
if self.is_input_continuous:
|
193 |
+
# TODO: should use out_channels for continuous projections
|
194 |
+
if use_linear_projection:
|
195 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
196 |
+
else:
|
197 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
198 |
+
elif self.is_input_vectorized:
|
199 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
200 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
201 |
+
elif self.is_input_patches:
|
202 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
203 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
204 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
205 |
+
|
206 |
+
def forward(
|
207 |
+
self,
|
208 |
+
hidden_states: torch.Tensor,
|
209 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
210 |
+
timestep: Optional[torch.LongTensor] = None,
|
211 |
+
class_labels: Optional[torch.LongTensor] = None,
|
212 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
213 |
+
attention_mask: Optional[torch.Tensor] = None,
|
214 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
215 |
+
return_dict: bool = True,
|
216 |
+
):
|
217 |
+
"""
|
218 |
+
The [`Transformer2DModel`] forward method.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
222 |
+
Input `hidden_states`.
|
223 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
224 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
225 |
+
self-attention.
|
226 |
+
timestep ( `torch.LongTensor`, *optional*):
|
227 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
228 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
229 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
230 |
+
`AdaLayerZeroNorm`.
|
231 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
232 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
233 |
+
|
234 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
235 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
236 |
+
|
237 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
238 |
+
above. This bias will be added to the cross-attention scores.
|
239 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
240 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
241 |
+
tuple.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
245 |
+
`tuple` where the first element is the sample tensor.
|
246 |
+
"""
|
247 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
248 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
249 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
250 |
+
# expects mask of shape:
|
251 |
+
# [batch, key_tokens]
|
252 |
+
# adds singleton query_tokens dimension:
|
253 |
+
# [batch, 1, key_tokens]
|
254 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
255 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
256 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
257 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
258 |
+
# assume that mask is expressed as:
|
259 |
+
# (1 = keep, 0 = discard)
|
260 |
+
# convert mask into a bias that can be added to attention scores:
|
261 |
+
# (keep = +0, discard = -10000.0)
|
262 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
263 |
+
attention_mask = attention_mask.unsqueeze(1)
|
264 |
+
|
265 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
266 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
267 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
268 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
269 |
+
|
270 |
+
# 1. Input
|
271 |
+
if self.is_input_continuous:
|
272 |
+
batch, _, height, width = hidden_states.shape
|
273 |
+
residual = hidden_states
|
274 |
+
|
275 |
+
hidden_states = self.norm(hidden_states)
|
276 |
+
if not self.use_linear_projection:
|
277 |
+
hidden_states = self.proj_in(hidden_states)
|
278 |
+
inner_dim = hidden_states.shape[1]
|
279 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
280 |
+
else:
|
281 |
+
inner_dim = hidden_states.shape[1]
|
282 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
283 |
+
hidden_states = self.proj_in(hidden_states)
|
284 |
+
elif self.is_input_vectorized:
|
285 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
286 |
+
elif self.is_input_patches:
|
287 |
+
hidden_states = self.pos_embed(hidden_states)
|
288 |
+
|
289 |
+
# 2. Blocks
|
290 |
+
for block in self.transformer_blocks:
|
291 |
+
hidden_states = block(
|
292 |
+
hidden_states,
|
293 |
+
attention_mask=attention_mask,
|
294 |
+
encoder_hidden_states=encoder_hidden_states,
|
295 |
+
encoder_attention_mask=encoder_attention_mask,
|
296 |
+
timestep=timestep,
|
297 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
298 |
+
class_labels=class_labels,
|
299 |
+
)
|
300 |
+
|
301 |
+
# 3. Output
|
302 |
+
if self.is_input_continuous:
|
303 |
+
if not self.use_linear_projection:
|
304 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
305 |
+
hidden_states = self.proj_out(hidden_states)
|
306 |
+
else:
|
307 |
+
hidden_states = self.proj_out(hidden_states)
|
308 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
309 |
+
|
310 |
+
output = hidden_states + residual
|
311 |
+
elif self.is_input_vectorized:
|
312 |
+
hidden_states = self.norm_out(hidden_states)
|
313 |
+
logits = self.out(hidden_states)
|
314 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
315 |
+
logits = logits.permute(0, 2, 1)
|
316 |
+
|
317 |
+
# log(p(x_0))
|
318 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
319 |
+
elif self.is_input_patches:
|
320 |
+
# TODO: cleanup!
|
321 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
322 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
323 |
+
)
|
324 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
325 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
326 |
+
hidden_states = self.proj_out_2(hidden_states)
|
327 |
+
|
328 |
+
# unpatchify
|
329 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
330 |
+
hidden_states = hidden_states.reshape(
|
331 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
332 |
+
)
|
333 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
334 |
+
output = hidden_states.reshape(
|
335 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
336 |
+
)
|
337 |
+
|
338 |
+
if not return_dict:
|
339 |
+
return (output,)
|
340 |
+
|
341 |
+
return Transformer2DModelOutput(sample=output)
|
models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,983 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.activations import get_activation
|
25 |
+
|
26 |
+
from diffusers.models.embeddings import (
|
27 |
+
GaussianFourierProjection,
|
28 |
+
ImageHintTimeEmbedding,
|
29 |
+
ImageProjection,
|
30 |
+
ImageTimeEmbedding,
|
31 |
+
TextImageProjection,
|
32 |
+
TextImageTimeEmbedding,
|
33 |
+
TextTimeEmbedding,
|
34 |
+
TimestepEmbedding,
|
35 |
+
Timesteps,
|
36 |
+
)
|
37 |
+
from diffusers.models.modeling_utils import ModelMixin
|
38 |
+
|
39 |
+
from models.attention_processor import AttentionProcessor, AttnProcessor
|
40 |
+
|
41 |
+
from models.unet_2d_blocks import (
|
42 |
+
CrossAttnDownBlock2D,
|
43 |
+
CrossAttnUpBlock2D,
|
44 |
+
DownBlock2D,
|
45 |
+
UNetMidBlock2DCrossAttn,
|
46 |
+
UNetMidBlock2DSimpleCrossAttn,
|
47 |
+
UpBlock2D,
|
48 |
+
get_down_block,
|
49 |
+
get_up_block,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class UNet2DConditionOutput(BaseOutput):
|
58 |
+
"""
|
59 |
+
The output of [`UNet2DConditionModel`].
|
60 |
+
|
61 |
+
Args:
|
62 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
63 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
64 |
+
"""
|
65 |
+
|
66 |
+
sample: torch.FloatTensor = None
|
67 |
+
|
68 |
+
|
69 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
70 |
+
r"""
|
71 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
72 |
+
shaped output.
|
73 |
+
|
74 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
75 |
+
for all models (such as downloading or saving).
|
76 |
+
|
77 |
+
Parameters:
|
78 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
79 |
+
Height and width of input/output sample.
|
80 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
81 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
82 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
83 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
84 |
+
Whether to flip the sin to cos in the time embedding.
|
85 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
86 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
87 |
+
The tuple of downsample blocks to use.
|
88 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
89 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
|
90 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
91 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
92 |
+
The tuple of upsample blocks to use.
|
93 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
94 |
+
Whether to include self-attention in the basic transformer blocks, see
|
95 |
+
[`~models.attention.BasicTransformerBlock`].
|
96 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
97 |
+
The tuple of output channels for each block.
|
98 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
99 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
100 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
101 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
102 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
103 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
104 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
105 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
106 |
+
The dimension of the cross attention features.
|
107 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
108 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
109 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
110 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
111 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
112 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
113 |
+
dimension to `cross_attention_dim`.
|
114 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
115 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
116 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
117 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
118 |
+
num_attention_heads (`int`, *optional*):
|
119 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
120 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
121 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
122 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
123 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
124 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
125 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
126 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
127 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
128 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
129 |
+
Dimension for the timestep embeddings.
|
130 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
131 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
132 |
+
class conditioning with `class_embed_type` equal to `None`.
|
133 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
134 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
135 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
136 |
+
An optional override for the dimension of the projected time embedding.
|
137 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
138 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
139 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
140 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
141 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
142 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
143 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
144 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
145 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
146 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
147 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
148 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
149 |
+
embeddings with the class embeddings.
|
150 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
151 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
152 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
153 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
154 |
+
otherwise.
|
155 |
+
"""
|
156 |
+
|
157 |
+
_supports_gradient_checkpointing = True
|
158 |
+
|
159 |
+
@register_to_config
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
sample_size: Optional[int] = None,
|
163 |
+
in_channels: int = 4,
|
164 |
+
out_channels: int = 4,
|
165 |
+
center_input_sample: bool = False,
|
166 |
+
flip_sin_to_cos: bool = True,
|
167 |
+
freq_shift: int = 0,
|
168 |
+
down_block_types: Tuple[str] = (
|
169 |
+
"CrossAttnDownBlock2D",
|
170 |
+
"CrossAttnDownBlock2D",
|
171 |
+
"CrossAttnDownBlock2D",
|
172 |
+
"DownBlock2D",
|
173 |
+
),
|
174 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
175 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
176 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
177 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
178 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
179 |
+
downsample_padding: int = 1,
|
180 |
+
mid_block_scale_factor: float = 1,
|
181 |
+
act_fn: str = "silu",
|
182 |
+
norm_num_groups: Optional[int] = 32,
|
183 |
+
norm_eps: float = 1e-5,
|
184 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
185 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
186 |
+
encoder_hid_dim: Optional[int] = None,
|
187 |
+
encoder_hid_dim_type: Optional[str] = None,
|
188 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
189 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
190 |
+
dual_cross_attention: bool = False,
|
191 |
+
use_linear_projection: bool = False,
|
192 |
+
class_embed_type: Optional[str] = None,
|
193 |
+
addition_embed_type: Optional[str] = None,
|
194 |
+
addition_time_embed_dim: Optional[int] = None,
|
195 |
+
num_class_embeds: Optional[int] = None,
|
196 |
+
upcast_attention: bool = False,
|
197 |
+
resnet_time_scale_shift: str = "default",
|
198 |
+
resnet_skip_time_act: bool = False,
|
199 |
+
resnet_out_scale_factor: int = 1.0,
|
200 |
+
time_embedding_type: str = "positional",
|
201 |
+
time_embedding_dim: Optional[int] = None,
|
202 |
+
time_embedding_act_fn: Optional[str] = None,
|
203 |
+
timestep_post_act: Optional[str] = None,
|
204 |
+
time_cond_proj_dim: Optional[int] = None,
|
205 |
+
conv_in_kernel: int = 3,
|
206 |
+
conv_out_kernel: int = 3,
|
207 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
208 |
+
class_embeddings_concat: bool = False,
|
209 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
210 |
+
cross_attention_norm: Optional[str] = None,
|
211 |
+
addition_embed_type_num_heads=64,
|
212 |
+
):
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
self.sample_size = sample_size
|
216 |
+
|
217 |
+
if num_attention_heads is not None:
|
218 |
+
raise ValueError(
|
219 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
220 |
+
)
|
221 |
+
|
222 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
223 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
224 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
225 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
226 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
227 |
+
# which is why we correct for the naming here.
|
228 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
229 |
+
|
230 |
+
# Check inputs
|
231 |
+
if len(down_block_types) != len(up_block_types):
|
232 |
+
raise ValueError(
|
233 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
234 |
+
)
|
235 |
+
|
236 |
+
if len(block_out_channels) != len(down_block_types):
|
237 |
+
raise ValueError(
|
238 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
239 |
+
)
|
240 |
+
|
241 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
242 |
+
raise ValueError(
|
243 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
244 |
+
)
|
245 |
+
|
246 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
247 |
+
raise ValueError(
|
248 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
249 |
+
)
|
250 |
+
|
251 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
252 |
+
raise ValueError(
|
253 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
254 |
+
)
|
255 |
+
|
256 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
257 |
+
raise ValueError(
|
258 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
259 |
+
)
|
260 |
+
|
261 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
262 |
+
raise ValueError(
|
263 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
264 |
+
)
|
265 |
+
|
266 |
+
# input
|
267 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
268 |
+
self.conv_in = nn.Conv2d(
|
269 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
270 |
+
)
|
271 |
+
|
272 |
+
# time
|
273 |
+
if time_embedding_type == "fourier":
|
274 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
275 |
+
if time_embed_dim % 2 != 0:
|
276 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
277 |
+
self.time_proj = GaussianFourierProjection(
|
278 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
279 |
+
)
|
280 |
+
timestep_input_dim = time_embed_dim
|
281 |
+
elif time_embedding_type == "positional":
|
282 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
283 |
+
|
284 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
285 |
+
timestep_input_dim = block_out_channels[0]
|
286 |
+
else:
|
287 |
+
raise ValueError(
|
288 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
289 |
+
)
|
290 |
+
|
291 |
+
self.time_embedding = TimestepEmbedding(
|
292 |
+
timestep_input_dim,
|
293 |
+
time_embed_dim,
|
294 |
+
act_fn=act_fn,
|
295 |
+
post_act_fn=timestep_post_act,
|
296 |
+
cond_proj_dim=time_cond_proj_dim,
|
297 |
+
)
|
298 |
+
|
299 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
300 |
+
encoder_hid_dim_type = "text_proj"
|
301 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
302 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
303 |
+
|
304 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
305 |
+
raise ValueError(
|
306 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
307 |
+
)
|
308 |
+
|
309 |
+
if encoder_hid_dim_type == "text_proj":
|
310 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
311 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
312 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
313 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
314 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
315 |
+
self.encoder_hid_proj = TextImageProjection(
|
316 |
+
text_embed_dim=encoder_hid_dim,
|
317 |
+
image_embed_dim=cross_attention_dim,
|
318 |
+
cross_attention_dim=cross_attention_dim,
|
319 |
+
)
|
320 |
+
elif encoder_hid_dim_type == "image_proj":
|
321 |
+
# Kandinsky 2.2
|
322 |
+
self.encoder_hid_proj = ImageProjection(
|
323 |
+
image_embed_dim=encoder_hid_dim,
|
324 |
+
cross_attention_dim=cross_attention_dim,
|
325 |
+
)
|
326 |
+
elif encoder_hid_dim_type is not None:
|
327 |
+
raise ValueError(
|
328 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
329 |
+
)
|
330 |
+
else:
|
331 |
+
self.encoder_hid_proj = None
|
332 |
+
|
333 |
+
# class embedding
|
334 |
+
if class_embed_type is None and num_class_embeds is not None:
|
335 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
336 |
+
elif class_embed_type == "timestep":
|
337 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
338 |
+
elif class_embed_type == "identity":
|
339 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
340 |
+
elif class_embed_type == "projection":
|
341 |
+
if projection_class_embeddings_input_dim is None:
|
342 |
+
raise ValueError(
|
343 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
344 |
+
)
|
345 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
346 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
347 |
+
# 2. it projects from an arbitrary input dimension.
|
348 |
+
#
|
349 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
350 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
351 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
352 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
353 |
+
elif class_embed_type == "simple_projection":
|
354 |
+
if projection_class_embeddings_input_dim is None:
|
355 |
+
raise ValueError(
|
356 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
357 |
+
)
|
358 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
359 |
+
else:
|
360 |
+
self.class_embedding = None
|
361 |
+
|
362 |
+
if addition_embed_type == "text":
|
363 |
+
if encoder_hid_dim is not None:
|
364 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
365 |
+
else:
|
366 |
+
text_time_embedding_from_dim = cross_attention_dim
|
367 |
+
|
368 |
+
self.add_embedding = TextTimeEmbedding(
|
369 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
370 |
+
)
|
371 |
+
elif addition_embed_type == "text_image":
|
372 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
373 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
374 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
375 |
+
self.add_embedding = TextImageTimeEmbedding(
|
376 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
377 |
+
)
|
378 |
+
elif addition_embed_type == "text_time":
|
379 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
380 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
381 |
+
elif addition_embed_type == "image":
|
382 |
+
# Kandinsky 2.2
|
383 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
384 |
+
elif addition_embed_type == "image_hint":
|
385 |
+
# Kandinsky 2.2 ControlNet
|
386 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
387 |
+
elif addition_embed_type is not None:
|
388 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
389 |
+
|
390 |
+
if time_embedding_act_fn is None:
|
391 |
+
self.time_embed_act = None
|
392 |
+
else:
|
393 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
394 |
+
|
395 |
+
self.down_blocks = nn.ModuleList([])
|
396 |
+
self.up_blocks = nn.ModuleList([])
|
397 |
+
|
398 |
+
if isinstance(only_cross_attention, bool):
|
399 |
+
if mid_block_only_cross_attention is None:
|
400 |
+
mid_block_only_cross_attention = only_cross_attention
|
401 |
+
|
402 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
403 |
+
|
404 |
+
if mid_block_only_cross_attention is None:
|
405 |
+
mid_block_only_cross_attention = False
|
406 |
+
|
407 |
+
if isinstance(num_attention_heads, int):
|
408 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
409 |
+
|
410 |
+
if isinstance(attention_head_dim, int):
|
411 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
412 |
+
|
413 |
+
if isinstance(cross_attention_dim, int):
|
414 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
415 |
+
|
416 |
+
if isinstance(layers_per_block, int):
|
417 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
418 |
+
|
419 |
+
if isinstance(transformer_layers_per_block, int):
|
420 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
421 |
+
|
422 |
+
if class_embeddings_concat:
|
423 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
424 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
425 |
+
# regular time embeddings
|
426 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
427 |
+
else:
|
428 |
+
blocks_time_embed_dim = time_embed_dim
|
429 |
+
|
430 |
+
# down
|
431 |
+
output_channel = block_out_channels[0]
|
432 |
+
for i, down_block_type in enumerate(down_block_types):
|
433 |
+
input_channel = output_channel
|
434 |
+
output_channel = block_out_channels[i]
|
435 |
+
is_final_block = i == len(block_out_channels) - 1
|
436 |
+
|
437 |
+
down_block = get_down_block(
|
438 |
+
down_block_type,
|
439 |
+
num_layers=layers_per_block[i],
|
440 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
441 |
+
in_channels=input_channel,
|
442 |
+
out_channels=output_channel,
|
443 |
+
temb_channels=blocks_time_embed_dim,
|
444 |
+
add_downsample=not is_final_block,
|
445 |
+
resnet_eps=norm_eps,
|
446 |
+
resnet_act_fn=act_fn,
|
447 |
+
resnet_groups=norm_num_groups,
|
448 |
+
cross_attention_dim=cross_attention_dim[i],
|
449 |
+
num_attention_heads=num_attention_heads[i],
|
450 |
+
downsample_padding=downsample_padding,
|
451 |
+
dual_cross_attention=dual_cross_attention,
|
452 |
+
use_linear_projection=use_linear_projection,
|
453 |
+
only_cross_attention=only_cross_attention[i],
|
454 |
+
upcast_attention=upcast_attention,
|
455 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
456 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
457 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
458 |
+
cross_attention_norm=cross_attention_norm,
|
459 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
460 |
+
)
|
461 |
+
self.down_blocks.append(down_block)
|
462 |
+
|
463 |
+
# mid
|
464 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
465 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
466 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
467 |
+
in_channels=block_out_channels[-1],
|
468 |
+
temb_channels=blocks_time_embed_dim,
|
469 |
+
resnet_eps=norm_eps,
|
470 |
+
resnet_act_fn=act_fn,
|
471 |
+
output_scale_factor=mid_block_scale_factor,
|
472 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
473 |
+
cross_attention_dim=cross_attention_dim[-1],
|
474 |
+
num_attention_heads=num_attention_heads[-1],
|
475 |
+
resnet_groups=norm_num_groups,
|
476 |
+
dual_cross_attention=dual_cross_attention,
|
477 |
+
use_linear_projection=use_linear_projection,
|
478 |
+
upcast_attention=upcast_attention,
|
479 |
+
)
|
480 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
481 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
482 |
+
in_channels=block_out_channels[-1],
|
483 |
+
temb_channels=blocks_time_embed_dim,
|
484 |
+
resnet_eps=norm_eps,
|
485 |
+
resnet_act_fn=act_fn,
|
486 |
+
output_scale_factor=mid_block_scale_factor,
|
487 |
+
cross_attention_dim=cross_attention_dim[-1],
|
488 |
+
attention_head_dim=attention_head_dim[-1],
|
489 |
+
resnet_groups=norm_num_groups,
|
490 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
491 |
+
skip_time_act=resnet_skip_time_act,
|
492 |
+
only_cross_attention=mid_block_only_cross_attention,
|
493 |
+
cross_attention_norm=cross_attention_norm,
|
494 |
+
)
|
495 |
+
elif mid_block_type is None:
|
496 |
+
self.mid_block = None
|
497 |
+
else:
|
498 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
499 |
+
|
500 |
+
# count how many layers upsample the images
|
501 |
+
self.num_upsamplers = 0
|
502 |
+
|
503 |
+
# up
|
504 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
505 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
506 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
507 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
508 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
509 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
510 |
+
|
511 |
+
output_channel = reversed_block_out_channels[0]
|
512 |
+
for i, up_block_type in enumerate(up_block_types):
|
513 |
+
is_final_block = i == len(block_out_channels) - 1
|
514 |
+
|
515 |
+
prev_output_channel = output_channel
|
516 |
+
output_channel = reversed_block_out_channels[i]
|
517 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
518 |
+
|
519 |
+
# add upsample block for all BUT final layer
|
520 |
+
if not is_final_block:
|
521 |
+
add_upsample = True
|
522 |
+
self.num_upsamplers += 1
|
523 |
+
else:
|
524 |
+
add_upsample = False
|
525 |
+
|
526 |
+
up_block = get_up_block(
|
527 |
+
up_block_type,
|
528 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
529 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
530 |
+
in_channels=input_channel,
|
531 |
+
out_channels=output_channel,
|
532 |
+
prev_output_channel=prev_output_channel,
|
533 |
+
temb_channels=blocks_time_embed_dim,
|
534 |
+
add_upsample=add_upsample,
|
535 |
+
resnet_eps=norm_eps,
|
536 |
+
resnet_act_fn=act_fn,
|
537 |
+
resnet_groups=norm_num_groups,
|
538 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
539 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
540 |
+
dual_cross_attention=dual_cross_attention,
|
541 |
+
use_linear_projection=use_linear_projection,
|
542 |
+
only_cross_attention=only_cross_attention[i],
|
543 |
+
upcast_attention=upcast_attention,
|
544 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
545 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
546 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
547 |
+
cross_attention_norm=cross_attention_norm,
|
548 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
549 |
+
)
|
550 |
+
self.up_blocks.append(up_block)
|
551 |
+
prev_output_channel = output_channel
|
552 |
+
|
553 |
+
# out
|
554 |
+
if norm_num_groups is not None:
|
555 |
+
self.conv_norm_out = nn.GroupNorm(
|
556 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
557 |
+
)
|
558 |
+
|
559 |
+
self.conv_act = get_activation(act_fn)
|
560 |
+
|
561 |
+
else:
|
562 |
+
self.conv_norm_out = None
|
563 |
+
self.conv_act = None
|
564 |
+
|
565 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
566 |
+
self.conv_out = nn.Conv2d(
|
567 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
568 |
+
)
|
569 |
+
|
570 |
+
@property
|
571 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
572 |
+
r"""
|
573 |
+
Returns:
|
574 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
575 |
+
indexed by its weight name.
|
576 |
+
"""
|
577 |
+
# set recursively
|
578 |
+
processors = {}
|
579 |
+
|
580 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
581 |
+
if hasattr(module, "set_processor"):
|
582 |
+
processors[f"{name}.processor"] = module.processor
|
583 |
+
|
584 |
+
for sub_name, child in module.named_children():
|
585 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
586 |
+
|
587 |
+
return processors
|
588 |
+
|
589 |
+
for name, module in self.named_children():
|
590 |
+
fn_recursive_add_processors(name, module, processors)
|
591 |
+
|
592 |
+
return processors
|
593 |
+
|
594 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
595 |
+
r"""
|
596 |
+
Sets the attention processor to use to compute attention.
|
597 |
+
|
598 |
+
Parameters:
|
599 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
600 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
601 |
+
for **all** `Attention` layers.
|
602 |
+
|
603 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
604 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
605 |
+
|
606 |
+
"""
|
607 |
+
count = len(self.attn_processors.keys())
|
608 |
+
|
609 |
+
if isinstance(processor, dict) and len(processor) != count:
|
610 |
+
raise ValueError(
|
611 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
612 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
613 |
+
)
|
614 |
+
|
615 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
616 |
+
if hasattr(module, "set_processor"):
|
617 |
+
if not isinstance(processor, dict):
|
618 |
+
module.set_processor(processor)
|
619 |
+
else:
|
620 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
621 |
+
|
622 |
+
for sub_name, child in module.named_children():
|
623 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
624 |
+
|
625 |
+
for name, module in self.named_children():
|
626 |
+
fn_recursive_attn_processor(name, module, processor)
|
627 |
+
|
628 |
+
def set_default_attn_processor(self):
|
629 |
+
"""
|
630 |
+
Disables custom attention processors and sets the default attention implementation.
|
631 |
+
"""
|
632 |
+
self.set_attn_processor(AttnProcessor())
|
633 |
+
|
634 |
+
def set_attention_slice(self, slice_size):
|
635 |
+
r"""
|
636 |
+
Enable sliced attention computation.
|
637 |
+
|
638 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
639 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
640 |
+
|
641 |
+
Args:
|
642 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
643 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
644 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
645 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
646 |
+
must be a multiple of `slice_size`.
|
647 |
+
"""
|
648 |
+
sliceable_head_dims = []
|
649 |
+
|
650 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
651 |
+
if hasattr(module, "set_attention_slice"):
|
652 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
653 |
+
|
654 |
+
for child in module.children():
|
655 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
656 |
+
|
657 |
+
# retrieve number of attention layers
|
658 |
+
for module in self.children():
|
659 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
660 |
+
|
661 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
662 |
+
|
663 |
+
if slice_size == "auto":
|
664 |
+
# half the attention head size is usually a good trade-off between
|
665 |
+
# speed and memory
|
666 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
667 |
+
elif slice_size == "max":
|
668 |
+
# make smallest slice possible
|
669 |
+
slice_size = num_sliceable_layers * [1]
|
670 |
+
|
671 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
672 |
+
|
673 |
+
if len(slice_size) != len(sliceable_head_dims):
|
674 |
+
raise ValueError(
|
675 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
676 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
677 |
+
)
|
678 |
+
|
679 |
+
for i in range(len(slice_size)):
|
680 |
+
size = slice_size[i]
|
681 |
+
dim = sliceable_head_dims[i]
|
682 |
+
if size is not None and size > dim:
|
683 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
684 |
+
|
685 |
+
# Recursively walk through all the children.
|
686 |
+
# Any children which exposes the set_attention_slice method
|
687 |
+
# gets the message
|
688 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
689 |
+
if hasattr(module, "set_attention_slice"):
|
690 |
+
module.set_attention_slice(slice_size.pop())
|
691 |
+
|
692 |
+
for child in module.children():
|
693 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
694 |
+
|
695 |
+
reversed_slice_size = list(reversed(slice_size))
|
696 |
+
for module in self.children():
|
697 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
698 |
+
|
699 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
700 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
701 |
+
module.gradient_checkpointing = value
|
702 |
+
|
703 |
+
def forward(
|
704 |
+
self,
|
705 |
+
sample: torch.FloatTensor,
|
706 |
+
timestep: Union[torch.Tensor, float, int],
|
707 |
+
encoder_hidden_states: torch.Tensor,
|
708 |
+
class_labels: Optional[torch.Tensor] = None,
|
709 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
710 |
+
attention_mask: Optional[torch.Tensor] = None,
|
711 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
712 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
713 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
714 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
715 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
716 |
+
return_dict: bool = True,
|
717 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
718 |
+
r"""
|
719 |
+
The [`UNet2DConditionModel`] forward method.
|
720 |
+
|
721 |
+
Args:
|
722 |
+
sample (`torch.FloatTensor`):
|
723 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
724 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
725 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
726 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
727 |
+
encoder_attention_mask (`torch.Tensor`):
|
728 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
729 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
730 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
731 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
732 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
733 |
+
tuple.
|
734 |
+
cross_attention_kwargs (`dict`, *optional*):
|
735 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
736 |
+
added_cond_kwargs: (`dict`, *optional*):
|
737 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
738 |
+
are passed along to the UNet blocks.
|
739 |
+
|
740 |
+
Returns:
|
741 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
742 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
743 |
+
a `tuple` is returned where the first element is the sample tensor.
|
744 |
+
"""
|
745 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
746 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
747 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
748 |
+
# on the fly if necessary.
|
749 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
750 |
+
|
751 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
752 |
+
forward_upsample_size = False
|
753 |
+
upsample_size = None
|
754 |
+
|
755 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
756 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
757 |
+
forward_upsample_size = True
|
758 |
+
|
759 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
760 |
+
# expects mask of shape:
|
761 |
+
# [batch, key_tokens]
|
762 |
+
# adds singleton query_tokens dimension:
|
763 |
+
# [batch, 1, key_tokens]
|
764 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
765 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
766 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
767 |
+
if attention_mask is not None:
|
768 |
+
# assume that mask is expressed as:
|
769 |
+
# (1 = keep, 0 = discard)
|
770 |
+
# convert mask into a bias that can be added to attention scores:
|
771 |
+
# (keep = +0, discard = -10000.0)
|
772 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
773 |
+
attention_mask = attention_mask.unsqueeze(1)
|
774 |
+
|
775 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
776 |
+
if encoder_attention_mask is not None:
|
777 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
778 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
779 |
+
|
780 |
+
# 0. center input if necessary
|
781 |
+
if self.config.center_input_sample:
|
782 |
+
sample = 2 * sample - 1.0
|
783 |
+
|
784 |
+
# 1. time
|
785 |
+
timesteps = timestep
|
786 |
+
if not torch.is_tensor(timesteps):
|
787 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
788 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
789 |
+
is_mps = sample.device.type == "mps"
|
790 |
+
if isinstance(timestep, float):
|
791 |
+
dtype = torch.float32 if is_mps else torch.float64
|
792 |
+
else:
|
793 |
+
dtype = torch.int32 if is_mps else torch.int64
|
794 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
795 |
+
elif len(timesteps.shape) == 0:
|
796 |
+
timesteps = timesteps[None].to(sample.device)
|
797 |
+
|
798 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
799 |
+
timesteps = timesteps.expand(sample.shape[0])
|
800 |
+
|
801 |
+
t_emb = self.time_proj(timesteps)
|
802 |
+
|
803 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
804 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
805 |
+
# there might be better ways to encapsulate this.
|
806 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
807 |
+
|
808 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
809 |
+
aug_emb = None
|
810 |
+
|
811 |
+
if self.class_embedding is not None:
|
812 |
+
if class_labels is None:
|
813 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
814 |
+
|
815 |
+
if self.config.class_embed_type == "timestep":
|
816 |
+
class_labels = self.time_proj(class_labels)
|
817 |
+
|
818 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
819 |
+
# there might be better ways to encapsulate this.
|
820 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
821 |
+
|
822 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
823 |
+
|
824 |
+
if self.config.class_embeddings_concat:
|
825 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
826 |
+
else:
|
827 |
+
emb = emb + class_emb
|
828 |
+
|
829 |
+
if self.config.addition_embed_type == "text":
|
830 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
831 |
+
elif self.config.addition_embed_type == "text_image":
|
832 |
+
# Kandinsky 2.1 - style
|
833 |
+
if "image_embeds" not in added_cond_kwargs:
|
834 |
+
raise ValueError(
|
835 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
836 |
+
)
|
837 |
+
|
838 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
839 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
840 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
841 |
+
elif self.config.addition_embed_type == "text_time":
|
842 |
+
if "text_embeds" not in added_cond_kwargs:
|
843 |
+
raise ValueError(
|
844 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
845 |
+
)
|
846 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
847 |
+
if "time_ids" not in added_cond_kwargs:
|
848 |
+
raise ValueError(
|
849 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
850 |
+
)
|
851 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
852 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
853 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
854 |
+
|
855 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
856 |
+
add_embeds = add_embeds.to(emb.dtype)
|
857 |
+
aug_emb = self.add_embedding(add_embeds)
|
858 |
+
elif self.config.addition_embed_type == "image":
|
859 |
+
# Kandinsky 2.2 - style
|
860 |
+
if "image_embeds" not in added_cond_kwargs:
|
861 |
+
raise ValueError(
|
862 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
863 |
+
)
|
864 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
865 |
+
aug_emb = self.add_embedding(image_embs)
|
866 |
+
elif self.config.addition_embed_type == "image_hint":
|
867 |
+
# Kandinsky 2.2 - style
|
868 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
869 |
+
raise ValueError(
|
870 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
871 |
+
)
|
872 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
873 |
+
hint = added_cond_kwargs.get("hint")
|
874 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
875 |
+
sample = torch.cat([sample, hint], dim=1)
|
876 |
+
|
877 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
878 |
+
|
879 |
+
if self.time_embed_act is not None:
|
880 |
+
emb = self.time_embed_act(emb)
|
881 |
+
|
882 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
883 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
884 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
885 |
+
# Kadinsky 2.1 - style
|
886 |
+
if "image_embeds" not in added_cond_kwargs:
|
887 |
+
raise ValueError(
|
888 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
889 |
+
)
|
890 |
+
|
891 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
892 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
893 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
894 |
+
# Kandinsky 2.2 - style
|
895 |
+
if "image_embeds" not in added_cond_kwargs:
|
896 |
+
raise ValueError(
|
897 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
898 |
+
)
|
899 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
900 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
901 |
+
# 2. pre-process
|
902 |
+
sample = self.conv_in(sample)
|
903 |
+
|
904 |
+
# 3. down
|
905 |
+
down_block_res_samples = (sample,)
|
906 |
+
for downsample_block in self.down_blocks:
|
907 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
908 |
+
sample, res_samples = downsample_block(
|
909 |
+
hidden_states=sample,
|
910 |
+
temb=emb,
|
911 |
+
encoder_hidden_states=encoder_hidden_states,
|
912 |
+
attention_mask=attention_mask,
|
913 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
914 |
+
encoder_attention_mask=encoder_attention_mask,
|
915 |
+
)
|
916 |
+
else:
|
917 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
918 |
+
|
919 |
+
down_block_res_samples += res_samples
|
920 |
+
|
921 |
+
if down_block_additional_residuals is not None:
|
922 |
+
new_down_block_res_samples = ()
|
923 |
+
|
924 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
925 |
+
down_block_res_samples, down_block_additional_residuals
|
926 |
+
):
|
927 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
928 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
929 |
+
|
930 |
+
down_block_res_samples = new_down_block_res_samples
|
931 |
+
|
932 |
+
# 4. mid
|
933 |
+
if self.mid_block is not None:
|
934 |
+
sample = self.mid_block(
|
935 |
+
sample,
|
936 |
+
emb,
|
937 |
+
encoder_hidden_states=encoder_hidden_states,
|
938 |
+
attention_mask=attention_mask,
|
939 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
940 |
+
encoder_attention_mask=encoder_attention_mask,
|
941 |
+
)
|
942 |
+
|
943 |
+
if mid_block_additional_residual is not None:
|
944 |
+
sample = sample + mid_block_additional_residual
|
945 |
+
|
946 |
+
# 5. up
|
947 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
948 |
+
is_final_block = i == len(self.up_blocks) - 1
|
949 |
+
|
950 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
951 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
952 |
+
|
953 |
+
# if we have not reached the final block and need to forward the
|
954 |
+
# upsample size, we do it here
|
955 |
+
if not is_final_block and forward_upsample_size:
|
956 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
957 |
+
|
958 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
959 |
+
sample = upsample_block(
|
960 |
+
hidden_states=sample,
|
961 |
+
temb=emb,
|
962 |
+
res_hidden_states_tuple=res_samples,
|
963 |
+
encoder_hidden_states=encoder_hidden_states,
|
964 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
965 |
+
upsample_size=upsample_size,
|
966 |
+
attention_mask=attention_mask,
|
967 |
+
encoder_attention_mask=encoder_attention_mask,
|
968 |
+
)
|
969 |
+
else:
|
970 |
+
sample = upsample_block(
|
971 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
972 |
+
)
|
973 |
+
|
974 |
+
# 6. post-process
|
975 |
+
if self.conv_norm_out:
|
976 |
+
sample = self.conv_norm_out(sample)
|
977 |
+
sample = self.conv_act(sample)
|
978 |
+
sample = self.conv_out(sample)
|
979 |
+
|
980 |
+
if not return_dict:
|
981 |
+
return (sample,)
|
982 |
+
|
983 |
+
return UNet2DConditionOutput(sample=sample)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
2 |
+
torch==1.13.1
|
3 |
+
torchvision==0.14.1
|
4 |
+
diffusers==0.18.2
|
5 |
+
transformers==4.27.0
|
6 |
+
safetensors==0.3.1
|
7 |
+
invisible_watermark==0.2.0
|
8 |
+
numpy==1.24.3
|
9 |
+
seaborn==0.12.2
|
10 |
+
accelerate==0.16.0
|
11 |
+
scikit-learn==1.1.3
|
rich-text-to-json-iframe.html
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
<head>
|
5 |
+
<title>Rich Text to JSON</title>
|
6 |
+
<link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
|
7 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
|
8 |
+
<link rel="stylesheet" type="text/css"
|
9 |
+
href="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.css">
|
10 |
+
<link rel="stylesheet"
|
11 |
+
href='https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap'>
|
12 |
+
<style>
|
13 |
+
html,
|
14 |
+
body {
|
15 |
+
background-color: white;
|
16 |
+
margin: 0;
|
17 |
+
}
|
18 |
+
|
19 |
+
/* Set default font-family */
|
20 |
+
.ql-snow .ql-tooltip::before {
|
21 |
+
content: "Footnote";
|
22 |
+
line-height: 26px;
|
23 |
+
margin-right: 8px;
|
24 |
+
}
|
25 |
+
|
26 |
+
.ql-snow .ql-tooltip[data-mode=link]::before {
|
27 |
+
content: "Enter footnote:";
|
28 |
+
}
|
29 |
+
|
30 |
+
.row {
|
31 |
+
margin-top: 15px;
|
32 |
+
margin-left: 0px;
|
33 |
+
margin-bottom: 15px;
|
34 |
+
}
|
35 |
+
|
36 |
+
.btn-primary {
|
37 |
+
color: #ffffff;
|
38 |
+
background-color: #2780e3;
|
39 |
+
border-color: #2780e3;
|
40 |
+
}
|
41 |
+
|
42 |
+
.btn-primary:hover {
|
43 |
+
color: #ffffff;
|
44 |
+
background-color: #1967be;
|
45 |
+
border-color: #1862b5;
|
46 |
+
}
|
47 |
+
|
48 |
+
.btn {
|
49 |
+
display: inline-block;
|
50 |
+
margin-bottom: 0;
|
51 |
+
font-weight: normal;
|
52 |
+
text-align: center;
|
53 |
+
vertical-align: middle;
|
54 |
+
touch-action: manipulation;
|
55 |
+
cursor: pointer;
|
56 |
+
background-image: none;
|
57 |
+
border: 1px solid transparent;
|
58 |
+
white-space: nowrap;
|
59 |
+
padding: 10px 18px;
|
60 |
+
font-size: 15px;
|
61 |
+
line-height: 1.42857143;
|
62 |
+
border-radius: 0;
|
63 |
+
user-select: none;
|
64 |
+
}
|
65 |
+
|
66 |
+
#standalone-container {
|
67 |
+
width: 100%;
|
68 |
+
background-color: #ffffff;
|
69 |
+
}
|
70 |
+
|
71 |
+
#editor-container {
|
72 |
+
font-family: "Aref Ruqaa";
|
73 |
+
font-size: 18px;
|
74 |
+
height: 250px;
|
75 |
+
width: 100%;
|
76 |
+
}
|
77 |
+
|
78 |
+
#toolbar-container {
|
79 |
+
font-family: "Aref Ruqaa";
|
80 |
+
display: flex;
|
81 |
+
flex-wrap: wrap;
|
82 |
+
}
|
83 |
+
|
84 |
+
#json-container {
|
85 |
+
max-width: 720px;
|
86 |
+
}
|
87 |
+
|
88 |
+
/* Set dropdown font-families */
|
89 |
+
#toolbar-container .ql-font span[data-label="Base"]::before {
|
90 |
+
font-family: "Aref Ruqaa";
|
91 |
+
}
|
92 |
+
|
93 |
+
#toolbar-container .ql-font span[data-label="Claude Monet"]::before {
|
94 |
+
font-family: "Mirza";
|
95 |
+
}
|
96 |
+
|
97 |
+
#toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
|
98 |
+
font-family: "Roboto";
|
99 |
+
}
|
100 |
+
|
101 |
+
#toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
|
102 |
+
font-family: "Comic Sans MS";
|
103 |
+
}
|
104 |
+
|
105 |
+
#toolbar-container .ql-font span[data-label="Pop Art"]::before {
|
106 |
+
font-family: "sofia";
|
107 |
+
}
|
108 |
+
|
109 |
+
#toolbar-container .ql-font span[data-label="Van Gogh"]::before {
|
110 |
+
font-family: "slabo 27px";
|
111 |
+
}
|
112 |
+
|
113 |
+
#toolbar-container .ql-font span[data-label="Pixel Art"]::before {
|
114 |
+
font-family: "inconsolata";
|
115 |
+
}
|
116 |
+
|
117 |
+
#toolbar-container .ql-font span[data-label="Rembrandt"]::before {
|
118 |
+
font-family: "ubuntu";
|
119 |
+
}
|
120 |
+
|
121 |
+
#toolbar-container .ql-font span[data-label="Cubism"]::before {
|
122 |
+
font-family: "Akronim";
|
123 |
+
}
|
124 |
+
|
125 |
+
#toolbar-container .ql-font span[data-label="Neon Art"]::before {
|
126 |
+
font-family: "Monoton";
|
127 |
+
}
|
128 |
+
|
129 |
+
/* Set content font-families */
|
130 |
+
.ql-font-mirza {
|
131 |
+
font-family: "Mirza";
|
132 |
+
}
|
133 |
+
|
134 |
+
.ql-font-roboto {
|
135 |
+
font-family: "Roboto";
|
136 |
+
}
|
137 |
+
|
138 |
+
.ql-font-cursive {
|
139 |
+
font-family: "Comic Sans MS";
|
140 |
+
}
|
141 |
+
|
142 |
+
.ql-font-sofia {
|
143 |
+
font-family: "sofia";
|
144 |
+
}
|
145 |
+
|
146 |
+
.ql-font-slabo {
|
147 |
+
font-family: "slabo 27px";
|
148 |
+
}
|
149 |
+
|
150 |
+
.ql-font-inconsolata {
|
151 |
+
font-family: "inconsolata";
|
152 |
+
}
|
153 |
+
|
154 |
+
.ql-font-ubuntu {
|
155 |
+
font-family: "ubuntu";
|
156 |
+
}
|
157 |
+
|
158 |
+
.ql-font-Akronim {
|
159 |
+
font-family: "Akronim";
|
160 |
+
}
|
161 |
+
|
162 |
+
.ql-font-Monoton {
|
163 |
+
font-family: "Monoton";
|
164 |
+
}
|
165 |
+
|
166 |
+
.ql-color .ql-picker-options [data-value=Color-Picker] {
|
167 |
+
background: none !important;
|
168 |
+
width: 100% !important;
|
169 |
+
height: 20px !important;
|
170 |
+
text-align: center;
|
171 |
+
}
|
172 |
+
|
173 |
+
.ql-color .ql-picker-options [data-value=Color-Picker]:before {
|
174 |
+
content: 'Color Picker';
|
175 |
+
}
|
176 |
+
|
177 |
+
.ql-color .ql-picker-options [data-value=Color-Picker]:hover {
|
178 |
+
border-color: transparent !important;
|
179 |
+
}
|
180 |
+
</style>
|
181 |
+
</head>
|
182 |
+
|
183 |
+
<body>
|
184 |
+
<div id="standalone-container">
|
185 |
+
<div id="toolbar-container">
|
186 |
+
<span class="ql-formats">
|
187 |
+
<select class="ql-font">
|
188 |
+
<option selected>Base</option>
|
189 |
+
<option value="mirza">Claude Monet</option>
|
190 |
+
<option value="roboto">Ukiyoe</option>
|
191 |
+
<option value="cursive">Cyber Punk</option>
|
192 |
+
<option value="sofia">Pop Art</option>
|
193 |
+
<option value="slabo">Van Gogh</option>
|
194 |
+
<option value="inconsolata">Pixel Art</option>
|
195 |
+
<option value="ubuntu">Rembrandt</option>
|
196 |
+
<option value="Akronim">Cubism</option>
|
197 |
+
<option value="Monoton">Neon Art</option>
|
198 |
+
</select>
|
199 |
+
<select class="ql-size">
|
200 |
+
<option value="18px">Small</option>
|
201 |
+
<option selected>Normal</option>
|
202 |
+
<option value="32px">Large</option>
|
203 |
+
<option value="50px">Huge</option>
|
204 |
+
</select>
|
205 |
+
</span>
|
206 |
+
<span class="ql-formats">
|
207 |
+
<button class="ql-strike"></button>
|
208 |
+
</span>
|
209 |
+
<!-- <span class="ql-formats">
|
210 |
+
<button class="ql-bold"></button>
|
211 |
+
<button class="ql-italic"></button>
|
212 |
+
<button class="ql-underline"></button>
|
213 |
+
</span> -->
|
214 |
+
<span class="ql-formats">
|
215 |
+
<select class="ql-color">
|
216 |
+
<option value="Color-Picker"></option>
|
217 |
+
</select>
|
218 |
+
<!-- <select class="ql-background"></select> -->
|
219 |
+
</span>
|
220 |
+
<!-- <span class="ql-formats">
|
221 |
+
<button class="ql-script" value="sub"></button>
|
222 |
+
<button class="ql-script" value="super"></button>
|
223 |
+
</span>
|
224 |
+
<span class="ql-formats">
|
225 |
+
<button class="ql-header" value="1"></button>
|
226 |
+
<button class="ql-header" value="2"></button>
|
227 |
+
<button class="ql-blockquote"></button>
|
228 |
+
<button class="ql-code-block"></button>
|
229 |
+
</span>
|
230 |
+
<span class="ql-formats">
|
231 |
+
<button class="ql-list" value="ordered"></button>
|
232 |
+
<button class="ql-list" value="bullet"></button>
|
233 |
+
<button class="ql-indent" value="-1"></button>
|
234 |
+
<button class="ql-indent" value="+1"></button>
|
235 |
+
</span>
|
236 |
+
<span class="ql-formats">
|
237 |
+
<button class="ql-direction" value="rtl"></button>
|
238 |
+
<select class="ql-align"></select>
|
239 |
+
</span>
|
240 |
+
<span class="ql-formats">
|
241 |
+
<button class="ql-link"></button>
|
242 |
+
<button class="ql-image"></button>
|
243 |
+
<button class="ql-video"></button>
|
244 |
+
<button class="ql-formula"></button>
|
245 |
+
</span> -->
|
246 |
+
<span class="ql-formats">
|
247 |
+
<button class="ql-link"></button>
|
248 |
+
</span>
|
249 |
+
<span class="ql-formats">
|
250 |
+
<button class="ql-clean"></button>
|
251 |
+
</span>
|
252 |
+
</div>
|
253 |
+
<div id="editor-container" style="height:300px;"></div>
|
254 |
+
</div>
|
255 |
+
<script src="https://cdn.quilljs.com/1.3.6/quill.min.js"></script>
|
256 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.0/jquery.min.js"></script>
|
257 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.js"></script>
|
258 |
+
<script>
|
259 |
+
|
260 |
+
// Register the customs format with Quill
|
261 |
+
const Font = Quill.import('formats/font');
|
262 |
+
Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
|
263 |
+
const Link = Quill.import('formats/link');
|
264 |
+
Link.sanitize = function (url) {
|
265 |
+
// modify url if desired
|
266 |
+
return url;
|
267 |
+
}
|
268 |
+
const SizeStyle = Quill.import('attributors/style/size');
|
269 |
+
SizeStyle.whitelist = ['10px', '18px', '20px', '32px', '50px', '60px', '64px', '70px'];
|
270 |
+
Quill.register(SizeStyle, true);
|
271 |
+
Quill.register(Link, true);
|
272 |
+
Quill.register(Font, true);
|
273 |
+
const icons = Quill.import('ui/icons');
|
274 |
+
icons['link'] = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
|
275 |
+
const quill = new Quill('#editor-container', {
|
276 |
+
modules: {
|
277 |
+
toolbar: {
|
278 |
+
container: '#toolbar-container',
|
279 |
+
},
|
280 |
+
},
|
281 |
+
theme: 'snow'
|
282 |
+
});
|
283 |
+
var toolbar = quill.getModule('toolbar');
|
284 |
+
$(toolbar.container).find('.ql-color').spectrum({
|
285 |
+
preferredFormat: "rgb",
|
286 |
+
showInput: true,
|
287 |
+
showInitial: true,
|
288 |
+
showPalette: true,
|
289 |
+
showSelectionPalette: true,
|
290 |
+
palette: [
|
291 |
+
["#000", "#444", "#666", "#999", "#ccc", "#eee", "#f3f3f3", "#fff"],
|
292 |
+
["#f00", "#f90", "#ff0", "#0f0", "#0ff", "#00f", "#90f", "#f0f"],
|
293 |
+
["#ea9999", "#f9cb9c", "#ffe599", "#b6d7a8", "#a2c4c9", "#9fc5e8", "#b4a7d6", "#d5a6bd"],
|
294 |
+
["#e06666", "#f6b26b", "#ffd966", "#93c47d", "#76a5af", "#6fa8dc", "#8e7cc3", "#c27ba0"],
|
295 |
+
["#c00", "#e69138", "#f1c232", "#6aa84f", "#45818e", "#3d85c6", "#674ea7", "#a64d79"],
|
296 |
+
["#900", "#b45f06", "#bf9000", "#38761d", "#134f5c", "#0b5394", "#351c75", "#741b47"],
|
297 |
+
["#600", "#783f04", "#7f6000", "#274e13", "#0c343d", "#073763", "#20124d", "#4c1130"]
|
298 |
+
],
|
299 |
+
change: function (color) {
|
300 |
+
var value = color.toHexString();
|
301 |
+
quill.format('color', value);
|
302 |
+
}
|
303 |
+
});
|
304 |
+
|
305 |
+
quill.on('text-change', () => {
|
306 |
+
// keep qull data inside _data to communicate with Gradio
|
307 |
+
document.body._data = quill.getContents()
|
308 |
+
})
|
309 |
+
function setQuillContents(content) {
|
310 |
+
quill.setContents(content);
|
311 |
+
document.body._data = quill.getContents();
|
312 |
+
}
|
313 |
+
document.body.setQuillContents = setQuillContents
|
314 |
+
</script>
|
315 |
+
<script src="https://unpkg.com/@popperjs/core@2/dist/umd/popper.min.js"></script>
|
316 |
+
<script src="https://unpkg.com/tippy.js@6/dist/tippy-bundle.umd.js"></script>
|
317 |
+
<script>
|
318 |
+
// With the above scripts loaded, you can call `tippy()` with a CSS
|
319 |
+
// selector and a `content` prop:
|
320 |
+
tippy('.ql-font', {
|
321 |
+
content: 'Add a style to the token',
|
322 |
+
});
|
323 |
+
tippy('.ql-size', {
|
324 |
+
content: 'Reweight the token',
|
325 |
+
});
|
326 |
+
tippy('.ql-color', {
|
327 |
+
content: 'Pick a color for the token',
|
328 |
+
});
|
329 |
+
tippy('.ql-link', {
|
330 |
+
content: 'Clarify the token',
|
331 |
+
});
|
332 |
+
tippy('.ql-strike', {
|
333 |
+
content: 'Change the token weight to be negative',
|
334 |
+
});
|
335 |
+
tippy('.ql-clean', {
|
336 |
+
content: 'Remove all the formats',
|
337 |
+
});
|
338 |
+
</script>
|
339 |
+
</body>
|
340 |
+
|
341 |
+
</html>
|
rich-text-to-json.js
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class RichTextEditor extends HTMLElement {
|
2 |
+
constructor() {
|
3 |
+
super();
|
4 |
+
this.loadExternalScripts();
|
5 |
+
this.attachShadow({ mode: 'open' });
|
6 |
+
this.shadowRoot.innerHTML = `
|
7 |
+
${RichTextEditor.header()}
|
8 |
+
${RichTextEditor.template()}
|
9 |
+
`;
|
10 |
+
}
|
11 |
+
connectedCallback() {
|
12 |
+
this.myQuill = this.mountQuill();
|
13 |
+
}
|
14 |
+
loadExternalScripts() {
|
15 |
+
const links = ["https://cdn.quilljs.com/1.3.6/quill.snow.css", "https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css", "https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap"]
|
16 |
+
links.forEach(link => {
|
17 |
+
const css = document.createElement("link");
|
18 |
+
css.href = link;
|
19 |
+
css.rel = "stylesheet"
|
20 |
+
document.head.appendChild(css);
|
21 |
+
})
|
22 |
+
|
23 |
+
}
|
24 |
+
static template() {
|
25 |
+
return `
|
26 |
+
<div id="standalone-container">
|
27 |
+
<div id="toolbar-container">
|
28 |
+
<span class="ql-formats">
|
29 |
+
<select class="ql-font">
|
30 |
+
<option selected>Base</option>
|
31 |
+
<option value="mirza">Claude Monet</option>
|
32 |
+
<option value="roboto">Ukiyoe</option>
|
33 |
+
<option value="cursive">Cyber Punk</option>
|
34 |
+
<option value="sofia">Pop Art</option>
|
35 |
+
<option value="slabo">Van Gogh</option>
|
36 |
+
<option value="inconsolata">Pixel Art</option>
|
37 |
+
<option value="ubuntu">Rembrandt</option>
|
38 |
+
<option value="Akronim">Cubism</option>
|
39 |
+
<option value="Monoton">Neon Art</option>
|
40 |
+
</select>
|
41 |
+
<select class="ql-size">
|
42 |
+
<option value="18px">Small</option>
|
43 |
+
<option selected>Normal</option>
|
44 |
+
<option value="32px">Large</option>
|
45 |
+
<option value="50px">Huge</option>
|
46 |
+
</select>
|
47 |
+
</span>
|
48 |
+
<span class="ql-formats">
|
49 |
+
<button class="ql-strike"></button>
|
50 |
+
</span>
|
51 |
+
<!-- <span class="ql-formats">
|
52 |
+
<button class="ql-bold"></button>
|
53 |
+
<button class="ql-italic"></button>
|
54 |
+
<button class="ql-underline"></button>
|
55 |
+
</span> -->
|
56 |
+
<span class="ql-formats">
|
57 |
+
<select class="ql-color"></select>
|
58 |
+
<!-- <select class="ql-background"></select> -->
|
59 |
+
</span>
|
60 |
+
<!-- <span class="ql-formats">
|
61 |
+
<button class="ql-script" value="sub"></button>
|
62 |
+
<button class="ql-script" value="super"></button>
|
63 |
+
</span>
|
64 |
+
<span class="ql-formats">
|
65 |
+
<button class="ql-header" value="1"></button>
|
66 |
+
<button class="ql-header" value="2"></button>
|
67 |
+
<button class="ql-blockquote"></button>
|
68 |
+
<button class="ql-code-block"></button>
|
69 |
+
</span>
|
70 |
+
<span class="ql-formats">
|
71 |
+
<button class="ql-list" value="ordered"></button>
|
72 |
+
<button class="ql-list" value="bullet"></button>
|
73 |
+
<button class="ql-indent" value="-1"></button>
|
74 |
+
<button class="ql-indent" value="+1"></button>
|
75 |
+
</span>
|
76 |
+
<span class="ql-formats">
|
77 |
+
<button class="ql-direction" value="rtl"></button>
|
78 |
+
<select class="ql-align"></select>
|
79 |
+
</span>
|
80 |
+
<span class="ql-formats">
|
81 |
+
<button class="ql-link"></button>
|
82 |
+
<button class="ql-image"></button>
|
83 |
+
<button class="ql-video"></button>
|
84 |
+
<button class="ql-formula"></button>
|
85 |
+
</span> -->
|
86 |
+
<span class="ql-formats">
|
87 |
+
<button class="ql-link"></button>
|
88 |
+
</span>
|
89 |
+
<span class="ql-formats">
|
90 |
+
<button class="ql-clean"></button>
|
91 |
+
</span>
|
92 |
+
</div>
|
93 |
+
<div id="editor-container"></div>
|
94 |
+
</div>
|
95 |
+
`;
|
96 |
+
}
|
97 |
+
|
98 |
+
static header() {
|
99 |
+
return `
|
100 |
+
<link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
|
101 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
|
102 |
+
<style>
|
103 |
+
/* Set default font-family */
|
104 |
+
.ql-snow .ql-tooltip::before {
|
105 |
+
content: "Footnote";
|
106 |
+
line-height: 26px;
|
107 |
+
margin-right: 8px;
|
108 |
+
}
|
109 |
+
|
110 |
+
.ql-snow .ql-tooltip[data-mode=link]::before {
|
111 |
+
content: "Enter footnote:";
|
112 |
+
}
|
113 |
+
|
114 |
+
.row {
|
115 |
+
margin-top: 15px;
|
116 |
+
margin-left: 0px;
|
117 |
+
margin-bottom: 15px;
|
118 |
+
}
|
119 |
+
|
120 |
+
.btn-primary {
|
121 |
+
color: #ffffff;
|
122 |
+
background-color: #2780e3;
|
123 |
+
border-color: #2780e3;
|
124 |
+
}
|
125 |
+
|
126 |
+
.btn-primary:hover {
|
127 |
+
color: #ffffff;
|
128 |
+
background-color: #1967be;
|
129 |
+
border-color: #1862b5;
|
130 |
+
}
|
131 |
+
|
132 |
+
.btn {
|
133 |
+
display: inline-block;
|
134 |
+
margin-bottom: 0;
|
135 |
+
font-weight: normal;
|
136 |
+
text-align: center;
|
137 |
+
vertical-align: middle;
|
138 |
+
touch-action: manipulation;
|
139 |
+
cursor: pointer;
|
140 |
+
background-image: none;
|
141 |
+
border: 1px solid transparent;
|
142 |
+
white-space: nowrap;
|
143 |
+
padding: 10px 18px;
|
144 |
+
font-size: 15px;
|
145 |
+
line-height: 1.42857143;
|
146 |
+
border-radius: 0;
|
147 |
+
user-select: none;
|
148 |
+
}
|
149 |
+
|
150 |
+
#standalone-container {
|
151 |
+
position: relative;
|
152 |
+
max-width: 720px;
|
153 |
+
background-color: #ffffff;
|
154 |
+
color: black !important;
|
155 |
+
z-index: 1000;
|
156 |
+
}
|
157 |
+
|
158 |
+
#editor-container {
|
159 |
+
font-family: "Aref Ruqaa";
|
160 |
+
font-size: 18px;
|
161 |
+
height: 250px;
|
162 |
+
}
|
163 |
+
|
164 |
+
#toolbar-container {
|
165 |
+
font-family: "Aref Ruqaa";
|
166 |
+
display: flex;
|
167 |
+
flex-wrap: wrap;
|
168 |
+
}
|
169 |
+
|
170 |
+
#json-container {
|
171 |
+
max-width: 720px;
|
172 |
+
}
|
173 |
+
|
174 |
+
/* Set dropdown font-families */
|
175 |
+
#toolbar-container .ql-font span[data-label="Base"]::before {
|
176 |
+
font-family: "Aref Ruqaa";
|
177 |
+
}
|
178 |
+
|
179 |
+
#toolbar-container .ql-font span[data-label="Claude Monet"]::before {
|
180 |
+
font-family: "Mirza";
|
181 |
+
}
|
182 |
+
|
183 |
+
#toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
|
184 |
+
font-family: "Roboto";
|
185 |
+
}
|
186 |
+
|
187 |
+
#toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
|
188 |
+
font-family: "Comic Sans MS";
|
189 |
+
}
|
190 |
+
|
191 |
+
#toolbar-container .ql-font span[data-label="Pop Art"]::before {
|
192 |
+
font-family: "sofia";
|
193 |
+
}
|
194 |
+
|
195 |
+
#toolbar-container .ql-font span[data-label="Van Gogh"]::before {
|
196 |
+
font-family: "slabo 27px";
|
197 |
+
}
|
198 |
+
|
199 |
+
#toolbar-container .ql-font span[data-label="Pixel Art"]::before {
|
200 |
+
font-family: "inconsolata";
|
201 |
+
}
|
202 |
+
|
203 |
+
#toolbar-container .ql-font span[data-label="Rembrandt"]::before {
|
204 |
+
font-family: "ubuntu";
|
205 |
+
}
|
206 |
+
|
207 |
+
#toolbar-container .ql-font span[data-label="Cubism"]::before {
|
208 |
+
font-family: "Akronim";
|
209 |
+
}
|
210 |
+
|
211 |
+
#toolbar-container .ql-font span[data-label="Neon Art"]::before {
|
212 |
+
font-family: "Monoton";
|
213 |
+
}
|
214 |
+
|
215 |
+
/* Set content font-families */
|
216 |
+
.ql-font-mirza {
|
217 |
+
font-family: "Mirza";
|
218 |
+
}
|
219 |
+
|
220 |
+
.ql-font-roboto {
|
221 |
+
font-family: "Roboto";
|
222 |
+
}
|
223 |
+
|
224 |
+
.ql-font-cursive {
|
225 |
+
font-family: "Comic Sans MS";
|
226 |
+
}
|
227 |
+
|
228 |
+
.ql-font-sofia {
|
229 |
+
font-family: "sofia";
|
230 |
+
}
|
231 |
+
|
232 |
+
.ql-font-slabo {
|
233 |
+
font-family: "slabo 27px";
|
234 |
+
}
|
235 |
+
|
236 |
+
.ql-font-inconsolata {
|
237 |
+
font-family: "inconsolata";
|
238 |
+
}
|
239 |
+
|
240 |
+
.ql-font-ubuntu {
|
241 |
+
font-family: "ubuntu";
|
242 |
+
}
|
243 |
+
|
244 |
+
.ql-font-Akronim {
|
245 |
+
font-family: "Akronim";
|
246 |
+
}
|
247 |
+
|
248 |
+
.ql-font-Monoton {
|
249 |
+
font-family: "Monoton";
|
250 |
+
}
|
251 |
+
</style>
|
252 |
+
`;
|
253 |
+
}
|
254 |
+
async mountQuill() {
|
255 |
+
// Register the customs format with Quill
|
256 |
+
const lib = await import("https://cdn.jsdelivr.net/npm/shadow-selection-polyfill");
|
257 |
+
const getRange = lib.getRange;
|
258 |
+
|
259 |
+
const Font = Quill.import('formats/font');
|
260 |
+
Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
|
261 |
+
const Link = Quill.import('formats/link');
|
262 |
+
Link.sanitize = function (url) {
|
263 |
+
// modify url if desired
|
264 |
+
return url;
|
265 |
+
}
|
266 |
+
const SizeStyle = Quill.import('attributors/style/size');
|
267 |
+
SizeStyle.whitelist = ['10px', '18px', '32px', '50px', '64px'];
|
268 |
+
Quill.register(SizeStyle, true);
|
269 |
+
Quill.register(Link, true);
|
270 |
+
Quill.register(Font, true);
|
271 |
+
const icons = Quill.import('ui/icons');
|
272 |
+
const icon = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
|
273 |
+
icons['link'] = icon;
|
274 |
+
const editorContainer = this.shadowRoot.querySelector('#editor-container')
|
275 |
+
const toolbarContainer = this.shadowRoot.querySelector('#toolbar-container')
|
276 |
+
const myQuill = new Quill(editorContainer, {
|
277 |
+
modules: {
|
278 |
+
toolbar: {
|
279 |
+
container: toolbarContainer,
|
280 |
+
},
|
281 |
+
},
|
282 |
+
theme: 'snow'
|
283 |
+
});
|
284 |
+
const normalizeNative = (nativeRange) => {
|
285 |
+
|
286 |
+
if (nativeRange) {
|
287 |
+
const range = nativeRange;
|
288 |
+
|
289 |
+
if (range.baseNode) {
|
290 |
+
range.startContainer = nativeRange.baseNode;
|
291 |
+
range.endContainer = nativeRange.focusNode;
|
292 |
+
range.startOffset = nativeRange.baseOffset;
|
293 |
+
range.endOffset = nativeRange.focusOffset;
|
294 |
+
|
295 |
+
if (range.endOffset < range.startOffset) {
|
296 |
+
range.startContainer = nativeRange.focusNode;
|
297 |
+
range.endContainer = nativeRange.baseNode;
|
298 |
+
range.startOffset = nativeRange.focusOffset;
|
299 |
+
range.endOffset = nativeRange.baseOffset;
|
300 |
+
}
|
301 |
+
}
|
302 |
+
|
303 |
+
if (range.startContainer) {
|
304 |
+
return {
|
305 |
+
start: { node: range.startContainer, offset: range.startOffset },
|
306 |
+
end: { node: range.endContainer, offset: range.endOffset },
|
307 |
+
native: range
|
308 |
+
};
|
309 |
+
}
|
310 |
+
}
|
311 |
+
|
312 |
+
return null
|
313 |
+
};
|
314 |
+
|
315 |
+
myQuill.selection.getNativeRange = () => {
|
316 |
+
|
317 |
+
const dom = myQuill.root.getRootNode();
|
318 |
+
const selection = getRange(dom);
|
319 |
+
const range = normalizeNative(selection);
|
320 |
+
|
321 |
+
return range;
|
322 |
+
};
|
323 |
+
let fromEditor = false;
|
324 |
+
editorContainer.addEventListener("pointerup", (e) => {
|
325 |
+
fromEditor = false;
|
326 |
+
});
|
327 |
+
editorContainer.addEventListener("pointerout", (e) => {
|
328 |
+
fromEditor = false;
|
329 |
+
});
|
330 |
+
editorContainer.addEventListener("pointerdown", (e) => {
|
331 |
+
fromEditor = true;
|
332 |
+
});
|
333 |
+
|
334 |
+
document.addEventListener("selectionchange", () => {
|
335 |
+
if (fromEditor) {
|
336 |
+
myQuill.selection.update()
|
337 |
+
}
|
338 |
+
});
|
339 |
+
|
340 |
+
|
341 |
+
myQuill.on('text-change', () => {
|
342 |
+
// keep qull data inside _data to communicate with Gradio
|
343 |
+
document.querySelector("#rich-text-root")._data = myQuill.getContents()
|
344 |
+
})
|
345 |
+
return myQuill
|
346 |
+
}
|
347 |
+
}
|
348 |
+
|
349 |
+
customElements.define('rich-text-editor', RichTextEditor);
|
share_btn.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
2 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
3 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
4 |
+
</svg>"""
|
5 |
+
|
6 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin" style="color: #ffffff;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
7 |
+
|
8 |
+
share_js = """async () => {
|
9 |
+
async function uploadFile(file){
|
10 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
11 |
+
const response = await fetch(UPLOAD_URL, {
|
12 |
+
method: 'POST',
|
13 |
+
headers: {
|
14 |
+
'Content-Type': file.type,
|
15 |
+
'X-Requested-With': 'XMLHttpRequest',
|
16 |
+
},
|
17 |
+
body: file, /// <- File inherits from Blob
|
18 |
+
});
|
19 |
+
const url = await response.text();
|
20 |
+
return url;
|
21 |
+
}
|
22 |
+
async function getInputImageFile(imageEl){
|
23 |
+
const res = await fetch(imageEl.src);
|
24 |
+
const blob = await res.blob();
|
25 |
+
const imageId = Date.now();
|
26 |
+
const fileName = `rich-text-image-${{imageId}}.png`;
|
27 |
+
return new File([blob], fileName, { type: 'image/png'});
|
28 |
+
}
|
29 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
30 |
+
const richEl = document.getElementById("rich-text-root");
|
31 |
+
const data = richEl? richEl.contentDocument.body._data : {};
|
32 |
+
const text_input = JSON.stringify(data);
|
33 |
+
const negative_prompt = gradioEl.querySelector('#negative_prompt input').value;
|
34 |
+
const seed = gradioEl.querySelector('#seed input').value;
|
35 |
+
const richTextImg = gradioEl.querySelector('#rich-text-image img');
|
36 |
+
const plainTextImg = gradioEl.querySelector('#plain-text-image img');
|
37 |
+
const text_input_obj = JSON.parse(text_input);
|
38 |
+
const plain_prompt = text_input_obj.ops.map(e=> e.insert).join('');
|
39 |
+
const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`;
|
40 |
+
|
41 |
+
const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`;
|
42 |
+
const shareBtnEl = gradioEl.querySelector('#share-btn');
|
43 |
+
const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
|
44 |
+
const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
|
45 |
+
if(!richTextImg){
|
46 |
+
return;
|
47 |
+
};
|
48 |
+
shareBtnEl.style.pointerEvents = 'none';
|
49 |
+
shareIconEl.style.display = 'none';
|
50 |
+
loadingIconEl.style.removeProperty('display');
|
51 |
+
|
52 |
+
const richImgFile = await getInputImageFile(richTextImg);
|
53 |
+
const plainImgFile = await getInputImageFile(plainTextImg);
|
54 |
+
const richImgURL = await uploadFile(richImgFile);
|
55 |
+
const plainImgURL = await uploadFile(plainImgFile);
|
56 |
+
|
57 |
+
const descriptionMd = `
|
58 |
+
### Plain Prompt
|
59 |
+
${plain_prompt}
|
60 |
+
|
61 |
+
π Shareable Link + Params: [here](${linkSrc})
|
62 |
+
|
63 |
+
### Rich Tech Image
|
64 |
+
<img src="${richImgURL}">
|
65 |
+
|
66 |
+
### Plain Text Image
|
67 |
+
<img src="${plainImgURL}">
|
68 |
+
|
69 |
+
`;
|
70 |
+
const params = new URLSearchParams({
|
71 |
+
title: titleTxt,
|
72 |
+
description: descriptionMd,
|
73 |
+
});
|
74 |
+
const paramsStr = params.toString();
|
75 |
+
window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank');
|
76 |
+
shareBtnEl.style.removeProperty('pointer-events');
|
77 |
+
shareIconEl.style.removeProperty('display');
|
78 |
+
loadingIconEl.style.display = 'none';
|
79 |
+
}"""
|
80 |
+
|
81 |
+
css = """
|
82 |
+
#share-btn-container {
|
83 |
+
display: flex;
|
84 |
+
padding-left: 0.5rem !important;
|
85 |
+
padding-right: 0.5rem !important;
|
86 |
+
background-color: #000000;
|
87 |
+
justify-content: center;
|
88 |
+
align-items: center;
|
89 |
+
border-radius: 9999px !important;
|
90 |
+
width: 13rem;
|
91 |
+
margin-top: 10px;
|
92 |
+
margin-left: auto;
|
93 |
+
flex: unset !important;
|
94 |
+
}
|
95 |
+
#share-btn {
|
96 |
+
all: initial;
|
97 |
+
color: #ffffff;
|
98 |
+
font-weight: 600;
|
99 |
+
cursor: pointer;
|
100 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
101 |
+
margin-left: 0.5rem !important;
|
102 |
+
padding-top: 0.25rem !important;
|
103 |
+
padding-bottom: 0.25rem !important;
|
104 |
+
right:0;
|
105 |
+
}
|
106 |
+
#share-btn * {
|
107 |
+
all: unset !important;
|
108 |
+
}
|
109 |
+
#share-btn-container div:nth-child(-n+2){
|
110 |
+
width: auto !important;
|
111 |
+
min-height: 0px !important;
|
112 |
+
}
|
113 |
+
#share-btn-container .wrap {
|
114 |
+
display: none !important;
|
115 |
+
}
|
116 |
+
"""
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/attention_utils.py
ADDED
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import matplotlib as mpl
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
from utils.richtext_utils import seed_everything
|
10 |
+
from sklearn.cluster import KMeans, SpectralClustering
|
11 |
+
|
12 |
+
# SelfAttentionLayers = [
|
13 |
+
# # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
|
14 |
+
# # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
|
15 |
+
# 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
16 |
+
# # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
17 |
+
# 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
18 |
+
# 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
19 |
+
# 'mid_block.attentions.0.transformer_blocks.0.attn1',
|
20 |
+
# 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
21 |
+
# 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
22 |
+
# 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
|
23 |
+
# # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
24 |
+
# 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
25 |
+
# # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
|
26 |
+
# # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
|
27 |
+
# # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
|
28 |
+
# # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
|
29 |
+
# ]
|
30 |
+
|
31 |
+
SelfAttentionLayers = [
|
32 |
+
# 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
|
33 |
+
# 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
|
34 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
35 |
+
# 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
36 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
37 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
38 |
+
'mid_block.attentions.0.transformer_blocks.0.attn1',
|
39 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
40 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
41 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
|
42 |
+
# 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
43 |
+
'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
44 |
+
# 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
|
45 |
+
# 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
|
46 |
+
# 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
|
47 |
+
# 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
|
48 |
+
]
|
49 |
+
|
50 |
+
|
51 |
+
CrossAttentionLayers = [
|
52 |
+
# 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
|
53 |
+
# 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
|
54 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
55 |
+
# 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
56 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
57 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
58 |
+
'mid_block.attentions.0.transformer_blocks.0.attn2',
|
59 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
60 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
61 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
|
62 |
+
# 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
63 |
+
'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
64 |
+
# 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
|
65 |
+
# 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
|
66 |
+
# 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
|
67 |
+
# 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
|
68 |
+
]
|
69 |
+
|
70 |
+
# CrossAttentionLayers = [
|
71 |
+
# 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
|
72 |
+
# 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
|
73 |
+
# 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
74 |
+
# 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
75 |
+
# 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
76 |
+
# 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
77 |
+
# 'mid_block.attentions.0.transformer_blocks.0.attn2',
|
78 |
+
# 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
79 |
+
# 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
80 |
+
# 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
|
81 |
+
# 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
82 |
+
# 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
83 |
+
# 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
|
84 |
+
# 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
|
85 |
+
# 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
|
86 |
+
# 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
|
87 |
+
# ]
|
88 |
+
|
89 |
+
# CrossAttentionLayers_XL = [
|
90 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
|
91 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
|
92 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
|
93 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
|
94 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
|
95 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
|
96 |
+
# 'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
|
97 |
+
# ]
|
98 |
+
CrossAttentionLayers_XL = [
|
99 |
+
'down_blocks.2.attentions.1.transformer_blocks.3.attn2',
|
100 |
+
'down_blocks.2.attentions.1.transformer_blocks.4.attn2',
|
101 |
+
'mid_block.attentions.0.transformer_blocks.0.attn2',
|
102 |
+
'mid_block.attentions.0.transformer_blocks.1.attn2',
|
103 |
+
'mid_block.attentions.0.transformer_blocks.2.attn2',
|
104 |
+
'mid_block.attentions.0.transformer_blocks.3.attn2',
|
105 |
+
'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
|
106 |
+
'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
|
107 |
+
'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
|
108 |
+
'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
|
109 |
+
'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
|
110 |
+
'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
|
111 |
+
'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
|
112 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2'
|
113 |
+
]
|
114 |
+
|
115 |
+
def split_attention_maps_over_steps(attention_maps):
|
116 |
+
r"""Function for splitting attention maps over steps.
|
117 |
+
Args:
|
118 |
+
attention_maps (dict): Dictionary of attention maps.
|
119 |
+
sampler_order (int): Order of the sampler.
|
120 |
+
"""
|
121 |
+
# This function splits attention maps into unconditional and conditional score and over steps
|
122 |
+
|
123 |
+
attention_maps_cond = dict() # Maps corresponding to conditional score
|
124 |
+
attention_maps_uncond = dict() # Maps corresponding to unconditional score
|
125 |
+
|
126 |
+
for layer in attention_maps.keys():
|
127 |
+
|
128 |
+
for step_num in range(len(attention_maps[layer])):
|
129 |
+
if step_num not in attention_maps_cond:
|
130 |
+
attention_maps_cond[step_num] = dict()
|
131 |
+
attention_maps_uncond[step_num] = dict()
|
132 |
+
|
133 |
+
attention_maps_uncond[step_num].update(
|
134 |
+
{layer: attention_maps[layer][step_num][:1]})
|
135 |
+
attention_maps_cond[step_num].update(
|
136 |
+
{layer: attention_maps[layer][step_num][1:2]})
|
137 |
+
|
138 |
+
return attention_maps_cond, attention_maps_uncond
|
139 |
+
|
140 |
+
|
141 |
+
def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
|
142 |
+
r"""Function to plot heatmaps for attention maps.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
attention_maps (dict): Dictionary of attention maps per layer
|
146 |
+
save_dir (str): Directory to save attention maps
|
147 |
+
prefix (str): Filename prefix for html files
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
Heatmaps, one per sample.
|
151 |
+
"""
|
152 |
+
|
153 |
+
html_names = []
|
154 |
+
|
155 |
+
idx = 0
|
156 |
+
html_list = []
|
157 |
+
|
158 |
+
for layer in attention_maps.keys():
|
159 |
+
if idx == 0:
|
160 |
+
# import ipdb;ipdb.set_trace()
|
161 |
+
# create a set of html files.
|
162 |
+
|
163 |
+
batch_size = attention_maps[layer].shape[0]
|
164 |
+
|
165 |
+
for sample_num in range(batch_size):
|
166 |
+
# html path
|
167 |
+
html_rel_path = os.path.join('sample_{}'.format(
|
168 |
+
sample_num), '{}.html'.format(prefix))
|
169 |
+
html_names.append(html_rel_path)
|
170 |
+
html_path = os.path.join(save_dir, html_rel_path)
|
171 |
+
os.makedirs(os.path.dirname(html_path), exist_ok=True)
|
172 |
+
html_list.append(open(html_path, 'wt'))
|
173 |
+
html_list[sample_num].write(
|
174 |
+
'<html><head></head><body><table>\n')
|
175 |
+
|
176 |
+
for sample_num in range(batch_size):
|
177 |
+
|
178 |
+
save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
|
179 |
+
prefix, 'layer_{}'.format(layer)) + '.jpg'
|
180 |
+
Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
|
181 |
+
|
182 |
+
layer_name = 'layer_{}'.format(layer)
|
183 |
+
html_list[sample_num].write(
|
184 |
+
f'<tr><td><h1>{layer_name}</h1></td></tr>\n')
|
185 |
+
|
186 |
+
prefix_stem = prefix.split('/')[-1]
|
187 |
+
relative_image_path = os.path.join(
|
188 |
+
prefix_stem, 'layer_{}'.format(layer)) + '.jpg'
|
189 |
+
html_list[sample_num].write(
|
190 |
+
f'<tr><td><img src=\"{relative_image_path}\"></td></tr>\n')
|
191 |
+
|
192 |
+
plt.figure()
|
193 |
+
plt.clf()
|
194 |
+
nrows = 2
|
195 |
+
ncols = 7
|
196 |
+
fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
|
197 |
+
|
198 |
+
fig.set_figheight(8)
|
199 |
+
fig.set_figwidth(28.5)
|
200 |
+
|
201 |
+
# axs[0].set_aspect('equal')
|
202 |
+
# axs[1].set_aspect('equal')
|
203 |
+
# axs[2].set_aspect('equal')
|
204 |
+
# axs[3].set_aspect('equal')
|
205 |
+
# axs[4].set_aspect('equal')
|
206 |
+
# axs[5].set_aspect('equal')
|
207 |
+
|
208 |
+
cmap = plt.get_cmap('YlOrRd')
|
209 |
+
|
210 |
+
for rid in range(nrows):
|
211 |
+
for cid in range(ncols):
|
212 |
+
tid = rid*ncols + cid
|
213 |
+
# import ipdb;ipdb.set_trace()
|
214 |
+
attention_map_cur = attention_maps[layer][sample_num, :, :, tid].numpy(
|
215 |
+
)
|
216 |
+
vmax = float(attention_map_cur.max())
|
217 |
+
vmin = float(attention_map_cur.min())
|
218 |
+
sns.heatmap(
|
219 |
+
attention_map_cur, annot=False, cbar=False, ax=axs[rid, cid],
|
220 |
+
cmap=cmap, vmin=vmin, vmax=vmax
|
221 |
+
)
|
222 |
+
axs[rid, cid].set_xlabel(tokens_vis[tid])
|
223 |
+
|
224 |
+
# axs[0].set_xlabel('Self attention')
|
225 |
+
# axs[1].set_xlabel('Temporal attention')
|
226 |
+
# axs[2].set_xlabel('T5 text attention')
|
227 |
+
# axs[3].set_xlabel('CLIP text attention')
|
228 |
+
# axs[4].set_xlabel('CLIP image attention')
|
229 |
+
# axs[5].set_xlabel('Null text token')
|
230 |
+
|
231 |
+
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
232 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
233 |
+
# fig.colorbar(sm, cax=axs[6])
|
234 |
+
|
235 |
+
fig.tight_layout()
|
236 |
+
plt.savefig(save_path, dpi=64)
|
237 |
+
plt.close('all')
|
238 |
+
|
239 |
+
if idx == (len(attention_maps.keys()) - 1):
|
240 |
+
for sample_num in range(batch_size):
|
241 |
+
html_list[sample_num].write('</table></body></html>')
|
242 |
+
html_list[sample_num].close()
|
243 |
+
|
244 |
+
idx += 1
|
245 |
+
|
246 |
+
return html_names
|
247 |
+
|
248 |
+
|
249 |
+
def create_recursive_html_link(html_path, save_dir):
|
250 |
+
r"""Function for creating recursive html links.
|
251 |
+
If the path is dir1/dir2/dir3/*.html,
|
252 |
+
we create chained directories
|
253 |
+
-dir1
|
254 |
+
dir1.html (has links to all children)
|
255 |
+
-dir2
|
256 |
+
dir2.html (has links to all children)
|
257 |
+
-dir3
|
258 |
+
dir3.html
|
259 |
+
|
260 |
+
Args:
|
261 |
+
html_path (str): Path to html file.
|
262 |
+
save_dir (str): Save directory.
|
263 |
+
"""
|
264 |
+
|
265 |
+
html_path_split = os.path.splitext(html_path)[0].split('/')
|
266 |
+
if len(html_path_split) == 1:
|
267 |
+
return
|
268 |
+
|
269 |
+
# First create the root directory
|
270 |
+
root_dir = html_path_split[0]
|
271 |
+
child_dir = html_path_split[1]
|
272 |
+
|
273 |
+
cur_html_path = os.path.join(save_dir, '{}.html'.format(root_dir))
|
274 |
+
if os.path.exists(cur_html_path):
|
275 |
+
|
276 |
+
fp = open(cur_html_path, 'r')
|
277 |
+
lines_written = fp.readlines()
|
278 |
+
fp.close()
|
279 |
+
|
280 |
+
fp = open(cur_html_path, 'a+')
|
281 |
+
child_path = os.path.join(root_dir, f'{child_dir}.html')
|
282 |
+
line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
|
283 |
+
|
284 |
+
if line_to_write not in lines_written:
|
285 |
+
fp.write('<html><head></head><body><table>\n')
|
286 |
+
fp.write(line_to_write)
|
287 |
+
fp.write('</table></body></html>')
|
288 |
+
fp.close()
|
289 |
+
|
290 |
+
else:
|
291 |
+
|
292 |
+
fp = open(cur_html_path, 'w')
|
293 |
+
|
294 |
+
child_path = os.path.join(root_dir, f'{child_dir}.html')
|
295 |
+
line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
|
296 |
+
|
297 |
+
fp.write('<html><head></head><body><table>\n')
|
298 |
+
fp.write(line_to_write)
|
299 |
+
fp.write('</table></body></html>')
|
300 |
+
|
301 |
+
fp.close()
|
302 |
+
|
303 |
+
child_path = '/'.join(html_path.split('/')[1:])
|
304 |
+
save_dir = os.path.join(save_dir, root_dir)
|
305 |
+
create_recursive_html_link(child_path, save_dir)
|
306 |
+
|
307 |
+
|
308 |
+
def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis):
|
309 |
+
r"""Function to visualize attention maps.
|
310 |
+
Args:
|
311 |
+
save_dir (str): Path to save attention maps
|
312 |
+
batch_size (int): Batch size
|
313 |
+
sampler_order (int): Sampler order
|
314 |
+
"""
|
315 |
+
|
316 |
+
rand_name = list(attention_maps_all.keys())[0]
|
317 |
+
nsteps = len(attention_maps_all[rand_name])
|
318 |
+
hw_ori = width * height
|
319 |
+
|
320 |
+
# html_path = save_dir + '.html'
|
321 |
+
text_input = save_dir.split('/')[-1]
|
322 |
+
# f = open(html_path, 'wt')
|
323 |
+
|
324 |
+
all_html_paths = []
|
325 |
+
|
326 |
+
for step_num in range(0, nsteps, 5):
|
327 |
+
|
328 |
+
# if cond_id == 'cond':
|
329 |
+
# attention_maps_cur = attention_maps_cond[step_num]
|
330 |
+
# else:
|
331 |
+
# attention_maps_cur = attention_maps_uncond[step_num]
|
332 |
+
|
333 |
+
attention_maps = dict()
|
334 |
+
|
335 |
+
for layer in attention_maps_all.keys():
|
336 |
+
|
337 |
+
attention_ind = attention_maps_all[layer][step_num].cpu()
|
338 |
+
|
339 |
+
# Attention maps are of shape [batch_size, nkeys, 77]
|
340 |
+
# since they are averaged out while collecting from hooks to save memory.
|
341 |
+
# Now split the heads from batch dimension
|
342 |
+
bs, hw, nclip = attention_ind.shape
|
343 |
+
down_ratio = np.sqrt(hw_ori // hw)
|
344 |
+
width_cur = int(width // down_ratio)
|
345 |
+
height_cur = int(height // down_ratio)
|
346 |
+
attention_ind = attention_ind.reshape(
|
347 |
+
bs, height_cur, width_cur, nclip)
|
348 |
+
|
349 |
+
attention_maps[layer] = attention_ind
|
350 |
+
|
351 |
+
# Obtain heatmaps corresponding to random heads and individual heads
|
352 |
+
|
353 |
+
html_names = save_attention_heatmaps(
|
354 |
+
attention_maps, tokens_vis, save_dir=save_dir, prefix='step_{}/attention_maps_cond'.format(
|
355 |
+
step_num)
|
356 |
+
)
|
357 |
+
|
358 |
+
# Write the logic for recursively creating pages
|
359 |
+
for html_name_cur in html_names:
|
360 |
+
all_html_paths.append(os.path.join(text_input, html_name_cur))
|
361 |
+
|
362 |
+
save_dir_root = '/'.join(save_dir.split('/')[0:-1])
|
363 |
+
for html_pth in all_html_paths:
|
364 |
+
create_recursive_html_link(html_pth, save_dir_root)
|
365 |
+
|
366 |
+
|
367 |
+
def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
|
368 |
+
for i, attn_map in enumerate(atten_map_list):
|
369 |
+
n_obj = len(attn_map)
|
370 |
+
plt.figure()
|
371 |
+
plt.clf()
|
372 |
+
|
373 |
+
fig, axs = plt.subplots(
|
374 |
+
ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)]+[0.1]))
|
375 |
+
|
376 |
+
fig.set_figheight(3)
|
377 |
+
fig.set_figwidth(3*n_obj+0.1)
|
378 |
+
|
379 |
+
cmap = plt.get_cmap('YlOrRd')
|
380 |
+
|
381 |
+
vmax = 0
|
382 |
+
vmin = 1
|
383 |
+
for tid in range(n_obj):
|
384 |
+
attention_map_cur = attn_map[tid]
|
385 |
+
vmax = max(vmax, float(attention_map_cur.max()))
|
386 |
+
vmin = min(vmin, float(attention_map_cur.min()))
|
387 |
+
|
388 |
+
for tid in range(n_obj):
|
389 |
+
sns.heatmap(
|
390 |
+
attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
|
391 |
+
cmap=cmap, vmin=vmin, vmax=vmax
|
392 |
+
)
|
393 |
+
axs[tid].set_axis_off()
|
394 |
+
|
395 |
+
if tokens_vis is not None:
|
396 |
+
if tid == n_obj-1:
|
397 |
+
axs_xlabel = 'other tokens'
|
398 |
+
else:
|
399 |
+
axs_xlabel = ''
|
400 |
+
for token_id in obj_tokens[tid]:
|
401 |
+
axs_xlabel += ' ' + tokens_vis[token_id.item() -
|
402 |
+
1][:-len('</w>')]
|
403 |
+
axs[tid].set_title(axs_xlabel)
|
404 |
+
|
405 |
+
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
406 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
407 |
+
fig.colorbar(sm, cax=axs[-1])
|
408 |
+
|
409 |
+
fig.tight_layout()
|
410 |
+
|
411 |
+
canvas = fig.canvas
|
412 |
+
canvas.draw()
|
413 |
+
width, height = canvas.get_width_height()
|
414 |
+
img = np.frombuffer(canvas.tostring_rgb(),
|
415 |
+
dtype='uint8').reshape((height, width, 3))
|
416 |
+
plt.savefig(os.path.join(
|
417 |
+
save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
|
418 |
+
plt.close('all')
|
419 |
+
return img
|
420 |
+
|
421 |
+
|
422 |
+
def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
|
423 |
+
preprocess=False):
|
424 |
+
r"""Function to visualize attention maps.
|
425 |
+
Args:
|
426 |
+
save_dir (str): Path to save attention maps
|
427 |
+
batch_size (int): Batch size
|
428 |
+
sampler_order (int): Sampler order
|
429 |
+
"""
|
430 |
+
|
431 |
+
# Split attention maps over steps
|
432 |
+
attention_maps_cond, _ = split_attention_maps_over_steps(
|
433 |
+
attention_maps
|
434 |
+
)
|
435 |
+
|
436 |
+
nsteps = len(attention_maps_cond)
|
437 |
+
hw_ori = width * height
|
438 |
+
|
439 |
+
attention_maps = []
|
440 |
+
for obj_token in obj_tokens:
|
441 |
+
attention_maps.append([])
|
442 |
+
|
443 |
+
for step_num in range(nsteps):
|
444 |
+
attention_maps_cur = attention_maps_cond[step_num]
|
445 |
+
|
446 |
+
for layer in attention_maps_cur.keys():
|
447 |
+
if step_num < 10 or layer not in CrossAttentionLayers:
|
448 |
+
continue
|
449 |
+
|
450 |
+
attention_ind = attention_maps_cur[layer].cpu()
|
451 |
+
|
452 |
+
# Attention maps are of shape [batch_size, nkeys, 77]
|
453 |
+
# since they are averaged out while collecting from hooks to save memory.
|
454 |
+
# Now split the heads from batch dimension
|
455 |
+
bs, hw, nclip = attention_ind.shape
|
456 |
+
down_ratio = np.sqrt(hw_ori // hw)
|
457 |
+
width_cur = int(width // down_ratio)
|
458 |
+
height_cur = int(height // down_ratio)
|
459 |
+
attention_ind = attention_ind.reshape(
|
460 |
+
bs, height_cur, width_cur, nclip)
|
461 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
462 |
+
if obj_token[0] == -1:
|
463 |
+
attention_map_prev = torch.stack(
|
464 |
+
[attention_maps[i][-1] for i in range(obj_id)]).sum(0)
|
465 |
+
attention_maps[obj_id].append(
|
466 |
+
attention_map_prev.max()-attention_map_prev)
|
467 |
+
else:
|
468 |
+
obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
|
469 |
+
0].permute([3, 0, 1, 2])
|
470 |
+
# obj_attention_map = attention_ind[:, :, :, obj_token].mean(-1, True).permute([3, 0, 1, 2])
|
471 |
+
obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
|
472 |
+
interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
|
473 |
+
attention_maps[obj_id].append(obj_attention_map)
|
474 |
+
|
475 |
+
attention_maps_averaged = []
|
476 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
477 |
+
if obj_id == len(obj_tokens) - 1:
|
478 |
+
attention_maps_averaged.append(
|
479 |
+
torch.cat(attention_maps[obj_id]).mean(0))
|
480 |
+
else:
|
481 |
+
attention_maps_averaged.append(
|
482 |
+
torch.cat(attention_maps[obj_id]).mean(0))
|
483 |
+
|
484 |
+
attention_maps_averaged_normalized = []
|
485 |
+
attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
|
486 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
487 |
+
attention_maps_averaged_normalized.append(
|
488 |
+
attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
|
489 |
+
|
490 |
+
if obj_tokens[-1][0] != -1:
|
491 |
+
attention_maps_averaged_normalized = (
|
492 |
+
torch.cat(attention_maps_averaged)/0.001).softmax(0)
|
493 |
+
attention_maps_averaged_normalized = [
|
494 |
+
attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
|
495 |
+
|
496 |
+
if preprocess:
|
497 |
+
selem = square(5)
|
498 |
+
selem = square(3)
|
499 |
+
selem = square(1)
|
500 |
+
attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
|
501 |
+
map[0].numpy()*255), selem) for map in attention_maps_averaged_normalized[:2]]
|
502 |
+
attention_maps_averaged_eroded = [(torch.from_numpy(map).unsqueeze(
|
503 |
+
0)/255. > 0.8).float() for map in attention_maps_averaged_eroded]
|
504 |
+
attention_maps_averaged_eroded.append(
|
505 |
+
1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
|
506 |
+
plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
|
507 |
+
attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
|
508 |
+
attention_maps_averaged_eroded = [attn_mask.unsqueeze(1).repeat(
|
509 |
+
[1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_eroded]
|
510 |
+
return attention_maps_averaged_eroded
|
511 |
+
else:
|
512 |
+
plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
|
513 |
+
obj_tokens, save_dir, seed, tokens_vis)
|
514 |
+
attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
|
515 |
+
[1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
|
516 |
+
return attention_maps_averaged_normalized
|
517 |
+
|
518 |
+
|
519 |
+
def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
|
520 |
+
r"""Function to visualize attention maps.
|
521 |
+
Args:
|
522 |
+
save_dir (str): Path to save attention maps
|
523 |
+
batch_size (int): Batch size
|
524 |
+
sampler_order (int): Sampler order
|
525 |
+
"""
|
526 |
+
|
527 |
+
_EPS = 1e-8
|
528 |
+
# Split attention maps over steps
|
529 |
+
attention_maps_cond, _ = split_attention_maps_over_steps(
|
530 |
+
attention_maps
|
531 |
+
)
|
532 |
+
|
533 |
+
nsteps = len(attention_maps_cond)
|
534 |
+
hw_ori = width * height
|
535 |
+
|
536 |
+
attention_maps = []
|
537 |
+
for obj_token in obj_tokens:
|
538 |
+
attention_maps.append([])
|
539 |
+
|
540 |
+
# for each side prompt, get attention maps for all steps and all layers
|
541 |
+
for step_num in range(nsteps):
|
542 |
+
attention_maps_cur = attention_maps_cond[step_num]
|
543 |
+
for layer in attention_maps_cur.keys():
|
544 |
+
attention_ind = attention_maps_cur[layer].cpu()
|
545 |
+
bs, hw, nclip = attention_ind.shape
|
546 |
+
down_ratio = np.sqrt(hw_ori // hw)
|
547 |
+
width_cur = int(width // down_ratio)
|
548 |
+
height_cur = int(height // down_ratio)
|
549 |
+
attention_ind = attention_ind.reshape(
|
550 |
+
bs, height_cur, width_cur, nclip)
|
551 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
552 |
+
if attention_ind.shape[1] > width//2:
|
553 |
+
continue
|
554 |
+
if obj_token[0] != -1:
|
555 |
+
obj_attention_map = attention_ind[:, :, :,
|
556 |
+
obj_token].mean(-1, True).permute([3, 0, 1, 2])
|
557 |
+
obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
|
558 |
+
interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
|
559 |
+
attention_maps[obj_id].append(obj_attention_map)
|
560 |
+
|
561 |
+
# average of all steps and layers, thresholding
|
562 |
+
attention_maps_thres = []
|
563 |
+
attention_maps_averaged = []
|
564 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
565 |
+
if obj_token[0] != -1:
|
566 |
+
average_map = torch.cat(attention_maps[obj_id]).mean(0)
|
567 |
+
attention_maps_averaged.append(average_map)
|
568 |
+
attention_maps_thres.append((average_map > threshold).float())
|
569 |
+
|
570 |
+
# get the remaining region except for the original prompt
|
571 |
+
attention_maps_averaged_normalized = []
|
572 |
+
attention_maps_averaged_sum = torch.cat(attention_maps_thres).sum(0) + _EPS
|
573 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
574 |
+
if obj_token[0] != -1:
|
575 |
+
attention_maps_averaged_normalized.append(
|
576 |
+
attention_maps_thres[obj_id]/attention_maps_averaged_sum)
|
577 |
+
else:
|
578 |
+
attention_map_prev = torch.stack(
|
579 |
+
attention_maps_averaged_normalized).sum(0)
|
580 |
+
attention_maps_averaged_normalized.append(1.-attention_map_prev)
|
581 |
+
|
582 |
+
plot_attention_maps(
|
583 |
+
[attention_maps_averaged, attention_maps_averaged_normalized], save_dir, seed)
|
584 |
+
|
585 |
+
attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
|
586 |
+
[1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
|
587 |
+
# attention_maps_averaged_normalized = attention_maps_averaged_normalized.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
|
588 |
+
return attention_maps_averaged_normalized
|
589 |
+
|
590 |
+
|
591 |
+
def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
|
592 |
+
preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
|
593 |
+
r"""Function to visualize attention maps.
|
594 |
+
Args:
|
595 |
+
save_dir (str): Path to save attention maps
|
596 |
+
batch_size (int): Batch size
|
597 |
+
sampler_order (int): Sampler order
|
598 |
+
"""
|
599 |
+
|
600 |
+
resolution = 32
|
601 |
+
# attn_maps_1024 = [attn_map for attn_map in selfattn_maps.values(
|
602 |
+
# ) if attn_map.shape[1] == resolution**2]
|
603 |
+
# attn_maps_1024 = torch.cat(attn_maps_1024).mean(0).cpu().numpy()
|
604 |
+
attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
|
605 |
+
for attn_map in selfattn_maps.values():
|
606 |
+
resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
|
607 |
+
if resolution_map != resolution:
|
608 |
+
continue
|
609 |
+
# attn_map = torch.nn.functional.interpolate(rearrange(attn_map, '1 c (h w) -> 1 c h w', h=resolution_map), (resolution, resolution),
|
610 |
+
# mode='bicubic', antialias=True)
|
611 |
+
# attn_map = rearrange(attn_map, '1 (h w) a b -> 1 (a b) h w', h=resolution_map)
|
612 |
+
attn_map = attn_map.reshape(
|
613 |
+
1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
|
614 |
+
attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
|
615 |
+
mode='bicubic', antialias=True)
|
616 |
+
attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
|
617 |
+
1, resolution**2, resolution_map**2))
|
618 |
+
attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
|
619 |
+
for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
|
620 |
+
if save_attn:
|
621 |
+
print('saving self-attention maps...', attn_maps_1024.shape)
|
622 |
+
torch.save(torch.from_numpy(attn_maps_1024),
|
623 |
+
'results/maps/selfattn_maps.pth')
|
624 |
+
seed_everything(kmeans_seed)
|
625 |
+
# import ipdb;ipdb.set_trace()
|
626 |
+
# kmeans = KMeans(n_clusters=num_segments,
|
627 |
+
# n_init=10).fit(attn_maps_1024)
|
628 |
+
# clusters = kmeans.labels_
|
629 |
+
# clusters = clusters.reshape(resolution, resolution)
|
630 |
+
# mesh = np.array(np.meshgrid(range(resolution), range(resolution), indexing='ij'), dtype=np.float32)/resolution
|
631 |
+
# dists = mesh.reshape(2, -1).T
|
632 |
+
# delta = 0.01
|
633 |
+
# spatial_sim = rbf_kernel(dists, dists)*delta
|
634 |
+
sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
|
635 |
+
assign_labels='kmeans')
|
636 |
+
clusters = sc.fit_predict(attn_maps_1024)
|
637 |
+
clusters = clusters.reshape(resolution, resolution)
|
638 |
+
fig = plt.figure()
|
639 |
+
plt.imshow(clusters)
|
640 |
+
plt.axis('off')
|
641 |
+
plt.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
|
642 |
+
bbox_inches='tight', pad_inches=0)
|
643 |
+
if return_vis:
|
644 |
+
canvas = fig.canvas
|
645 |
+
canvas.draw()
|
646 |
+
cav_width, cav_height = canvas.get_width_height()
|
647 |
+
segments_vis = np.frombuffer(canvas.tostring_rgb(),
|
648 |
+
dtype='uint8').reshape((cav_height, cav_width, 3))
|
649 |
+
|
650 |
+
plt.close()
|
651 |
+
|
652 |
+
# label the segmentation mask using cross-attention maps
|
653 |
+
cross_attn_maps_1024 = []
|
654 |
+
for attn_map in crossattn_maps.values():
|
655 |
+
resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
|
656 |
+
# if resolution_map != 16:
|
657 |
+
# continue
|
658 |
+
attn_map = attn_map.reshape(
|
659 |
+
1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
|
660 |
+
attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
|
661 |
+
mode='bicubic', antialias=True)
|
662 |
+
cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
|
663 |
+
|
664 |
+
cross_attn_maps_1024 = torch.cat(
|
665 |
+
cross_attn_maps_1024).mean(0).cpu().numpy()
|
666 |
+
normalized_span_maps = []
|
667 |
+
for token_ids in obj_tokens:
|
668 |
+
token_ids = torch.clip(token_ids, 0, 76)
|
669 |
+
span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
|
670 |
+
normalized_span_map = np.zeros_like(span_token_maps)
|
671 |
+
for i in range(span_token_maps.shape[-1]):
|
672 |
+
curr_noun_map = span_token_maps[:, :, i]
|
673 |
+
normalized_span_map[:, :, i] = (
|
674 |
+
# curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
|
675 |
+
curr_noun_map - np.abs(curr_noun_map.min())) / (curr_noun_map.max()-curr_noun_map.min())
|
676 |
+
normalized_span_maps.append(normalized_span_map)
|
677 |
+
foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
|
678 |
+
) for normalized_span_map in normalized_span_maps]
|
679 |
+
background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
|
680 |
+
for c in range(num_segments):
|
681 |
+
cluster_mask = np.zeros_like(clusters)
|
682 |
+
cluster_mask[clusters == c] = 1.
|
683 |
+
is_foreground = False
|
684 |
+
for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
|
685 |
+
score_maps = [cluster_mask * normalized_span_map[:, :, i]
|
686 |
+
for i in range(len(token_ids))]
|
687 |
+
scores = [score_map.sum() / cluster_mask.sum()
|
688 |
+
for score_map in score_maps]
|
689 |
+
if max(scores) > segment_threshold:
|
690 |
+
foreground_nouns_map += cluster_mask
|
691 |
+
is_foreground = True
|
692 |
+
if not is_foreground:
|
693 |
+
background_map += cluster_mask
|
694 |
+
foreground_token_maps.append(background_map)
|
695 |
+
|
696 |
+
# resize the token maps and visualization
|
697 |
+
resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
|
698 |
+
0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
|
699 |
+
|
700 |
+
resized_token_maps = resized_token_maps / \
|
701 |
+
(resized_token_maps.sum(0, True)+1e-8)
|
702 |
+
resized_token_maps = [token_map.unsqueeze(
|
703 |
+
0) for token_map in resized_token_maps]
|
704 |
+
foreground_token_maps = [token_map[None, :, :]
|
705 |
+
for token_map in foreground_token_maps]
|
706 |
+
if preprocess:
|
707 |
+
selem = square(5)
|
708 |
+
eroded_token_maps = torch.stack([torch.from_numpy(erosion(skimage.img_as_float(
|
709 |
+
map[0].numpy()*255), selem))/255. for map in resized_token_maps[:-1]]).clamp(0, 1)
|
710 |
+
# import ipdb; ipdb.set_trace()
|
711 |
+
eroded_background_maps = (1-eroded_token_maps.sum(0, True)).clamp(0, 1)
|
712 |
+
eroded_token_maps = torch.cat([eroded_token_maps, eroded_background_maps])
|
713 |
+
eroded_token_maps = eroded_token_maps / (eroded_token_maps.sum(0, True)+1e-8)
|
714 |
+
resized_token_maps = [token_map.unsqueeze(
|
715 |
+
0) for token_map in eroded_token_maps]
|
716 |
+
|
717 |
+
token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
|
718 |
+
save_dir, kmeans_seed, tokens_vis)
|
719 |
+
resized_token_maps = [token_map.unsqueeze(1).repeat(
|
720 |
+
[1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
|
721 |
+
if return_vis:
|
722 |
+
return resized_token_maps, segments_vis, token_maps_vis
|
723 |
+
else:
|
724 |
+
return resized_token_maps
|
utils/richtext_utils.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
COLORS = {
|
8 |
+
'brown': [165, 42, 42],
|
9 |
+
'red': [255, 0, 0],
|
10 |
+
'pink': [253, 108, 158],
|
11 |
+
'orange': [255, 165, 0],
|
12 |
+
'yellow': [255, 255, 0],
|
13 |
+
'purple': [128, 0, 128],
|
14 |
+
'green': [0, 128, 0],
|
15 |
+
'blue': [0, 0, 255],
|
16 |
+
'white': [255, 255, 255],
|
17 |
+
'gray': [128, 128, 128],
|
18 |
+
'black': [0, 0, 0],
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def seed_everything(seed):
|
23 |
+
random.seed(seed)
|
24 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
|
29 |
+
|
30 |
+
def hex_to_rgb(hex_string, return_nearest_color=False):
|
31 |
+
r"""
|
32 |
+
Covert Hex triplet to RGB triplet.
|
33 |
+
"""
|
34 |
+
# Remove '#' symbol if present
|
35 |
+
hex_string = hex_string.lstrip('#')
|
36 |
+
# Convert hex values to integers
|
37 |
+
red = int(hex_string[0:2], 16)
|
38 |
+
green = int(hex_string[2:4], 16)
|
39 |
+
blue = int(hex_string[4:6], 16)
|
40 |
+
rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
|
41 |
+
if return_nearest_color:
|
42 |
+
nearest_color = find_nearest_color(rgb)
|
43 |
+
return rgb.cuda(), nearest_color
|
44 |
+
return rgb.cuda()
|
45 |
+
|
46 |
+
|
47 |
+
def find_nearest_color(rgb):
|
48 |
+
r"""
|
49 |
+
Find the nearest neighbor color given the RGB value.
|
50 |
+
"""
|
51 |
+
if isinstance(rgb, list) or isinstance(rgb, tuple):
|
52 |
+
rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
|
53 |
+
color_distance = torch.FloatTensor([np.linalg.norm(
|
54 |
+
rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()])
|
55 |
+
nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
|
56 |
+
return nearest_color
|
57 |
+
|
58 |
+
|
59 |
+
def font2style(font):
|
60 |
+
r"""
|
61 |
+
Convert the font name to the style name.
|
62 |
+
"""
|
63 |
+
return {'mirza': 'Claud Monet, impressionism, oil on canvas',
|
64 |
+
'roboto': 'Ukiyoe',
|
65 |
+
'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
|
66 |
+
'sofia': 'Pop Art, masterpiece, andy warhol',
|
67 |
+
'slabo': 'Vincent Van Gogh',
|
68 |
+
'inconsolata': 'Pixel Art, 8 bits, 16 bits',
|
69 |
+
'ubuntu': 'Rembrandt',
|
70 |
+
'Monoton': 'neon art, colorful light, highly details, octane render',
|
71 |
+
'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
|
72 |
+
|
73 |
+
|
74 |
+
def parse_json(json_str):
|
75 |
+
r"""
|
76 |
+
Convert the JSON string to attributes.
|
77 |
+
"""
|
78 |
+
# initialze region-base attributes.
|
79 |
+
base_text_prompt = ''
|
80 |
+
style_text_prompts = []
|
81 |
+
footnote_text_prompts = []
|
82 |
+
footnote_target_tokens = []
|
83 |
+
color_text_prompts = []
|
84 |
+
color_rgbs = []
|
85 |
+
color_names = []
|
86 |
+
size_text_prompts_and_sizes = []
|
87 |
+
|
88 |
+
# parse the attributes from JSON.
|
89 |
+
prev_style = None
|
90 |
+
prev_color_rgb = None
|
91 |
+
use_grad_guidance = False
|
92 |
+
for span in json_str['ops']:
|
93 |
+
text_prompt = span['insert'].rstrip('\n')
|
94 |
+
base_text_prompt += span['insert'].rstrip('\n')
|
95 |
+
if text_prompt == ' ':
|
96 |
+
continue
|
97 |
+
if 'attributes' in span:
|
98 |
+
if 'font' in span['attributes']:
|
99 |
+
style = font2style(span['attributes']['font'])
|
100 |
+
if prev_style == style:
|
101 |
+
prev_text_prompt = style_text_prompts[-1].split('in the style of')[
|
102 |
+
0]
|
103 |
+
style_text_prompts[-1] = prev_text_prompt + \
|
104 |
+
' ' + text_prompt + f' in the style of {style}'
|
105 |
+
else:
|
106 |
+
style_text_prompts.append(
|
107 |
+
text_prompt + f' in the style of {style}')
|
108 |
+
prev_style = style
|
109 |
+
else:
|
110 |
+
prev_style = None
|
111 |
+
if 'link' in span['attributes']:
|
112 |
+
footnote_text_prompts.append(span['attributes']['link'])
|
113 |
+
footnote_target_tokens.append(text_prompt)
|
114 |
+
font_size = 1
|
115 |
+
if 'size' in span['attributes'] and 'strike' not in span['attributes']:
|
116 |
+
font_size = float(span['attributes']['size'][:-2])/3.
|
117 |
+
elif 'size' in span['attributes'] and 'strike' in span['attributes']:
|
118 |
+
font_size = -float(span['attributes']['size'][:-2])/3.
|
119 |
+
elif 'size' not in span['attributes'] and 'strike' not in span['attributes']:
|
120 |
+
font_size = 1
|
121 |
+
if 'color' in span['attributes']:
|
122 |
+
use_grad_guidance = True
|
123 |
+
color_rgb, nearest_color = hex_to_rgb(
|
124 |
+
span['attributes']['color'], True)
|
125 |
+
if prev_color_rgb == color_rgb:
|
126 |
+
prev_text_prompt = color_text_prompts[-1]
|
127 |
+
color_text_prompts[-1] = prev_text_prompt + \
|
128 |
+
' ' + text_prompt
|
129 |
+
else:
|
130 |
+
color_rgbs.append(color_rgb)
|
131 |
+
color_names.append(nearest_color)
|
132 |
+
color_text_prompts.append(text_prompt)
|
133 |
+
if font_size != 1:
|
134 |
+
size_text_prompts_and_sizes.append([text_prompt, font_size])
|
135 |
+
return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
136 |
+
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
|
137 |
+
|
138 |
+
|
139 |
+
def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
140 |
+
footnote_target_tokens, color_text_prompts, color_names):
|
141 |
+
r"""
|
142 |
+
Algorithm 1 in the paper.
|
143 |
+
"""
|
144 |
+
region_text_prompts = []
|
145 |
+
region_target_token_ids = []
|
146 |
+
base_tokens = model.tokenizer._tokenize(base_text_prompt)
|
147 |
+
# process the style text prompt
|
148 |
+
for text_prompt in style_text_prompts:
|
149 |
+
region_text_prompts.append(text_prompt)
|
150 |
+
region_target_token_ids.append([])
|
151 |
+
style_tokens = model.tokenizer._tokenize(
|
152 |
+
text_prompt.split('in the style of')[0])
|
153 |
+
for style_token in style_tokens:
|
154 |
+
region_target_token_ids[-1].append(
|
155 |
+
base_tokens.index(style_token)+1)
|
156 |
+
|
157 |
+
# process the complementary text prompt
|
158 |
+
for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
|
159 |
+
region_target_token_ids.append([])
|
160 |
+
region_text_prompts.append(footnote_text_prompt)
|
161 |
+
style_tokens = model.tokenizer._tokenize(text_prompt)
|
162 |
+
for style_token in style_tokens:
|
163 |
+
region_target_token_ids[-1].append(
|
164 |
+
base_tokens.index(style_token)+1)
|
165 |
+
|
166 |
+
# process the color text prompt
|
167 |
+
for color_text_prompt, color_name in zip(color_text_prompts, color_names):
|
168 |
+
region_target_token_ids.append([])
|
169 |
+
region_text_prompts.append(color_name+' '+color_text_prompt)
|
170 |
+
style_tokens = model.tokenizer._tokenize(color_text_prompt)
|
171 |
+
for style_token in style_tokens:
|
172 |
+
region_target_token_ids[-1].append(
|
173 |
+
base_tokens.index(style_token)+1)
|
174 |
+
|
175 |
+
# process the remaining tokens without any attributes
|
176 |
+
region_text_prompts.append(base_text_prompt)
|
177 |
+
region_target_token_ids_all = [
|
178 |
+
id for ids in region_target_token_ids for id in ids]
|
179 |
+
target_token_ids_rest = [id for id in range(
|
180 |
+
1, len(base_tokens)+1) if id not in region_target_token_ids_all]
|
181 |
+
region_target_token_ids.append(target_token_ids_rest)
|
182 |
+
|
183 |
+
region_target_token_ids = [torch.LongTensor(
|
184 |
+
obj_token_id) for obj_token_id in region_target_token_ids]
|
185 |
+
return region_text_prompts, region_target_token_ids, base_tokens
|
186 |
+
|
187 |
+
|
188 |
+
def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
|
189 |
+
r"""
|
190 |
+
Control the token impact using font sizes.
|
191 |
+
"""
|
192 |
+
word_pos = []
|
193 |
+
font_sizes = []
|
194 |
+
for text_prompt, font_size in size_text_prompts_and_sizes:
|
195 |
+
size_tokens = model.tokenizer._tokenize(text_prompt)
|
196 |
+
for size_token in size_tokens:
|
197 |
+
word_pos.append(base_tokens.index(size_token)+1)
|
198 |
+
font_sizes.append(font_size)
|
199 |
+
if len(word_pos) > 0:
|
200 |
+
word_pos = torch.LongTensor(word_pos).cuda()
|
201 |
+
font_sizes = torch.FloatTensor(font_sizes).cuda()
|
202 |
+
else:
|
203 |
+
word_pos = None
|
204 |
+
font_sizes = None
|
205 |
+
text_format_dict = {
|
206 |
+
'word_pos': word_pos,
|
207 |
+
'font_size': font_sizes,
|
208 |
+
}
|
209 |
+
return text_format_dict
|
210 |
+
|
211 |
+
|
212 |
+
def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
|
213 |
+
guidance_start_step=999, color_guidance_weight=1):
|
214 |
+
r"""
|
215 |
+
Control the token impact using font sizes.
|
216 |
+
"""
|
217 |
+
color_target_token_ids = []
|
218 |
+
for text_prompt in color_text_prompts:
|
219 |
+
color_target_token_ids.append([])
|
220 |
+
color_tokens = model.tokenizer._tokenize(text_prompt)
|
221 |
+
for color_token in color_tokens:
|
222 |
+
color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
|
223 |
+
color_target_token_ids_all = [
|
224 |
+
id for ids in color_target_token_ids for id in ids]
|
225 |
+
color_target_token_ids_rest = [id for id in range(
|
226 |
+
1, len(base_tokens)+1) if id not in color_target_token_ids_all]
|
227 |
+
color_target_token_ids.append(color_target_token_ids_rest)
|
228 |
+
color_target_token_ids = [torch.LongTensor(
|
229 |
+
obj_token_id) for obj_token_id in color_target_token_ids]
|
230 |
+
|
231 |
+
text_format_dict['target_RGB'] = color_rgbs
|
232 |
+
text_format_dict['guidance_start_step'] = guidance_start_step
|
233 |
+
text_format_dict['color_guidance_weight'] = color_guidance_weight
|
234 |
+
return text_format_dict, color_target_token_ids
|