Spaces:
Build error
Build error
Duplicate from songweig/rich-text-to-image
Browse filesCo-authored-by: Songwei Ge <[email protected]>
- .gitattributes +34 -0
- .gitignore +3 -0
- README.md +13 -0
- app.py +557 -0
- models/attention.py +904 -0
- models/region_diffusion.py +461 -0
- models/unet_2d_blocks.py +1855 -0
- models/unet_2d_condition.py +411 -0
- requirements.txt +9 -0
- rich-text-to-json-iframe.html +341 -0
- rich-text-to-json.js +349 -0
- share_btn.py +116 -0
- utils/.DS_Store +0 -0
- utils/attention_utils.py +318 -0
- utils/richtext_utils.py +234 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
venv
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Rich Text To Image
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.27.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: songweig/rich-text-to-image
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import argparse
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
from models.region_diffusion import RegionDiffusion
|
12 |
+
from utils.attention_utils import get_token_maps
|
13 |
+
from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
|
14 |
+
get_attention_control_input, get_gradient_guidance_input
|
15 |
+
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from PIL import Image, ImageOps
|
19 |
+
from share_btn import community_icon_html, loading_icon_html, share_js, css
|
20 |
+
|
21 |
+
|
22 |
+
help_text = """
|
23 |
+
If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
|
24 |
+
1. If you format only a portion of a word rather than the complete word, an error may occur.
|
25 |
+
2. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
|
26 |
+
3. Consider using a different seed.
|
27 |
+
"""
|
28 |
+
|
29 |
+
|
30 |
+
canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
|
31 |
+
get_js_data = """
|
32 |
+
async (text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, rich_text_input, height, width, steps, guidance_weights) => {
|
33 |
+
const richEl = document.getElementById("rich-text-root");
|
34 |
+
const data = richEl? richEl.contentDocument.body._data : {};
|
35 |
+
return [text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, JSON.stringify(data), height, width, steps, guidance_weights];
|
36 |
+
}
|
37 |
+
"""
|
38 |
+
set_js_data = """
|
39 |
+
async (text_input) => {
|
40 |
+
const richEl = document.getElementById("rich-text-root");
|
41 |
+
const data = text_input ? JSON.parse(text_input) : null;
|
42 |
+
if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
|
43 |
+
}
|
44 |
+
"""
|
45 |
+
|
46 |
+
get_window_url_params = """
|
47 |
+
async (url_params) => {
|
48 |
+
const params = new URLSearchParams(window.location.search);
|
49 |
+
url_params = Object.fromEntries(params);
|
50 |
+
return [url_params];
|
51 |
+
}
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
def load_url_params(url_params):
|
56 |
+
if 'prompt' in url_params:
|
57 |
+
return gr.update(visible=True), url_params
|
58 |
+
else:
|
59 |
+
return gr.update(visible=False), url_params
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
64 |
+
model = RegionDiffusion(device)
|
65 |
+
|
66 |
+
def generate(
|
67 |
+
text_input: str,
|
68 |
+
negative_text: str,
|
69 |
+
num_segments: int,
|
70 |
+
segment_threshold: float,
|
71 |
+
inject_interval: float,
|
72 |
+
inject_background: float,
|
73 |
+
seed: int,
|
74 |
+
color_guidance_weight: float,
|
75 |
+
rich_text_input: str,
|
76 |
+
height: int,
|
77 |
+
width: int,
|
78 |
+
steps: int,
|
79 |
+
guidance_weight: float,
|
80 |
+
):
|
81 |
+
run_dir = 'results/'
|
82 |
+
os.makedirs(run_dir, exist_ok=True)
|
83 |
+
# Load region diffusion model.
|
84 |
+
height = int(height) if height else 512
|
85 |
+
width = int(width) if width else 512
|
86 |
+
steps = 41 if not steps else steps
|
87 |
+
guidance_weight = 8.5 if not guidance_weight else guidance_weight
|
88 |
+
text_input = rich_text_input if rich_text_input != '' and rich_text_input != None else text_input
|
89 |
+
print('text_input', text_input, width, height, steps, guidance_weight, num_segments, segment_threshold, inject_interval, inject_background, color_guidance_weight, negative_text)
|
90 |
+
if (text_input == '' or rich_text_input == ''):
|
91 |
+
raise gr.Error("Please enter some text.")
|
92 |
+
# parse json to span attributes
|
93 |
+
base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
94 |
+
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
|
95 |
+
json.loads(text_input))
|
96 |
+
|
97 |
+
# create control input for region diffusion
|
98 |
+
region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
|
99 |
+
model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
100 |
+
footnote_target_tokens, color_text_prompts, color_names)
|
101 |
+
|
102 |
+
# create control input for cross attention
|
103 |
+
text_format_dict = get_attention_control_input(
|
104 |
+
model, base_tokens, size_text_prompts_and_sizes)
|
105 |
+
|
106 |
+
# create control input for region guidance
|
107 |
+
text_format_dict, color_target_token_ids = get_gradient_guidance_input(
|
108 |
+
model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
|
109 |
+
|
110 |
+
seed_everything(seed)
|
111 |
+
|
112 |
+
# get token maps from plain text to image generation.
|
113 |
+
begin_time = time.time()
|
114 |
+
if model.selfattn_maps is None and model.crossattn_maps is None:
|
115 |
+
model.remove_tokenmap_hooks()
|
116 |
+
model.register_tokenmap_hooks()
|
117 |
+
else:
|
118 |
+
model.reset_attention_maps()
|
119 |
+
model.remove_tokenmap_hooks()
|
120 |
+
plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
|
121 |
+
height=height, width=width, num_inference_steps=steps,
|
122 |
+
guidance_scale=guidance_weight)
|
123 |
+
print('time lapses to get attention maps: %.4f' %
|
124 |
+
(time.time()-begin_time))
|
125 |
+
seed_everything(seed)
|
126 |
+
color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
127 |
+
512//8, 512//8, color_target_token_ids[:-1], seed,
|
128 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
129 |
+
return_vis=True)
|
130 |
+
seed_everything(seed)
|
131 |
+
model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
|
132 |
+
512//8, 512//8, region_target_token_ids[:-1], seed,
|
133 |
+
base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
|
134 |
+
return_vis=True)
|
135 |
+
color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
|
136 |
+
for obj_mask in color_obj_masks[:-1]:
|
137 |
+
color_obj_atten_all += obj_mask
|
138 |
+
color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
|
139 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
140 |
+
antialias=True)
|
141 |
+
for color_obj_mask in color_obj_masks]
|
142 |
+
text_format_dict['color_obj_atten'] = color_obj_masks
|
143 |
+
text_format_dict['color_obj_atten_all'] = color_obj_atten_all
|
144 |
+
model.remove_tokenmap_hooks()
|
145 |
+
|
146 |
+
# generate image from rich text
|
147 |
+
begin_time = time.time()
|
148 |
+
seed_everything(seed)
|
149 |
+
rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
|
150 |
+
height=height, width=width, num_inference_steps=steps,
|
151 |
+
guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
|
152 |
+
text_format_dict=text_format_dict, inject_selfattn=inject_interval,
|
153 |
+
inject_background=inject_background)
|
154 |
+
print('time lapses to generate image from rich text: %.4f' %
|
155 |
+
(time.time()-begin_time))
|
156 |
+
return [plain_img[0], rich_img[0], segments_vis, token_maps]
|
157 |
+
|
158 |
+
with gr.Blocks(css=css) as demo:
|
159 |
+
url_params = gr.JSON({}, visible=False, label="URL Params")
|
160 |
+
gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
|
161 |
+
<p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
|
162 |
+
<p> UMD, Adobe, CMU <p/>
|
163 |
+
<p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
|
164 |
+
<p> For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
|
165 |
+
with gr.Row():
|
166 |
+
with gr.Column():
|
167 |
+
rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
|
168 |
+
rich_text_input = gr.Textbox(value="", visible=False)
|
169 |
+
text_input = gr.Textbox(
|
170 |
+
label='Rich-text JSON Input',
|
171 |
+
visible=False,
|
172 |
+
max_lines=1,
|
173 |
+
placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
|
174 |
+
elem_id="text_input"
|
175 |
+
)
|
176 |
+
negative_prompt = gr.Textbox(
|
177 |
+
label='Negative Prompt',
|
178 |
+
max_lines=1,
|
179 |
+
placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
|
180 |
+
elem_id="negative_prompt"
|
181 |
+
)
|
182 |
+
segment_threshold = gr.Slider(label='Token map threshold',
|
183 |
+
info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
|
184 |
+
minimum=0,
|
185 |
+
maximum=1,
|
186 |
+
step=0.01,
|
187 |
+
value=0.25)
|
188 |
+
inject_interval = gr.Slider(label='Detail preservation',
|
189 |
+
info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
|
190 |
+
minimum=0,
|
191 |
+
maximum=1,
|
192 |
+
step=0.01,
|
193 |
+
value=0.)
|
194 |
+
inject_background = gr.Slider(label='Unformatted token preservation',
|
195 |
+
info='(To affect less the tokens without any rich-text attributes, increase this.)',
|
196 |
+
minimum=0,
|
197 |
+
maximum=1,
|
198 |
+
step=0.01,
|
199 |
+
value=0.3)
|
200 |
+
color_guidance_weight = gr.Slider(label='Color weight',
|
201 |
+
info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
|
202 |
+
minimum=0,
|
203 |
+
maximum=2,
|
204 |
+
step=0.1,
|
205 |
+
value=0.5)
|
206 |
+
num_segments = gr.Slider(label='Number of segments',
|
207 |
+
minimum=2,
|
208 |
+
maximum=20,
|
209 |
+
step=1,
|
210 |
+
value=9)
|
211 |
+
seed = gr.Slider(label='Seed',
|
212 |
+
minimum=0,
|
213 |
+
maximum=100000,
|
214 |
+
step=1,
|
215 |
+
value=6,
|
216 |
+
elem_id="seed"
|
217 |
+
)
|
218 |
+
with gr.Accordion('Other Parameters', open=False):
|
219 |
+
steps = gr.Slider(label='Number of Steps',
|
220 |
+
minimum=0,
|
221 |
+
maximum=500,
|
222 |
+
step=1,
|
223 |
+
value=41)
|
224 |
+
guidance_weight = gr.Slider(label='CFG weight',
|
225 |
+
minimum=0,
|
226 |
+
maximum=50,
|
227 |
+
step=0.1,
|
228 |
+
value=8.5)
|
229 |
+
width = gr.Dropdown(choices=[512],
|
230 |
+
value=512,
|
231 |
+
label='Width',
|
232 |
+
visible=True)
|
233 |
+
height = gr.Dropdown(choices=[512],
|
234 |
+
value=512,
|
235 |
+
label='height',
|
236 |
+
visible=True)
|
237 |
+
|
238 |
+
with gr.Row():
|
239 |
+
with gr.Column(scale=1, min_width=100):
|
240 |
+
generate_button = gr.Button("Generate")
|
241 |
+
load_params_button = gr.Button(
|
242 |
+
"Load from URL Params", visible=True)
|
243 |
+
with gr.Column():
|
244 |
+
richtext_result = gr.Image(
|
245 |
+
label='Rich-text', elem_id="rich-text-image")
|
246 |
+
richtext_result.style(height=512)
|
247 |
+
with gr.Row():
|
248 |
+
plaintext_result = gr.Image(
|
249 |
+
label='Plain-text', elem_id="plain-text-image")
|
250 |
+
segments = gr.Image(label='Segmentation')
|
251 |
+
with gr.Row():
|
252 |
+
token_map = gr.Image(label='Token Maps')
|
253 |
+
with gr.Row(visible=False) as share_row:
|
254 |
+
with gr.Group(elem_id="share-btn-container"):
|
255 |
+
community_icon = gr.HTML(community_icon_html)
|
256 |
+
loading_icon = gr.HTML(loading_icon_html)
|
257 |
+
share_button = gr.Button(
|
258 |
+
"Share to community", elem_id="share-btn")
|
259 |
+
share_button.click(None, [], [], _js=share_js)
|
260 |
+
with gr.Row():
|
261 |
+
gr.Markdown(help_text)
|
262 |
+
|
263 |
+
with gr.Row():
|
264 |
+
footnote_examples = [
|
265 |
+
[
|
266 |
+
'{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
|
267 |
+
'',
|
268 |
+
5,
|
269 |
+
0.3,
|
270 |
+
0,
|
271 |
+
0.5,
|
272 |
+
6,
|
273 |
+
0,
|
274 |
+
None,
|
275 |
+
],
|
276 |
+
[
|
277 |
+
'{"ops":[{"insert":"A "},{"attributes":{"link":"Thor Kitchen 30 Inch Wide Freestanding Gas Range with Automatic Re-Ignition System"},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
|
278 |
+
'',
|
279 |
+
7,
|
280 |
+
0.5,
|
281 |
+
0,
|
282 |
+
0.5,
|
283 |
+
6,
|
284 |
+
0,
|
285 |
+
None,
|
286 |
+
],
|
287 |
+
[
|
288 |
+
'{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
|
289 |
+
'',
|
290 |
+
5,
|
291 |
+
0.3,
|
292 |
+
0,
|
293 |
+
0.1,
|
294 |
+
4,
|
295 |
+
0,
|
296 |
+
None,
|
297 |
+
],
|
298 |
+
]
|
299 |
+
|
300 |
+
gr.Examples(examples=footnote_examples,
|
301 |
+
label='Footnote examples',
|
302 |
+
inputs=[
|
303 |
+
text_input,
|
304 |
+
negative_prompt,
|
305 |
+
num_segments,
|
306 |
+
segment_threshold,
|
307 |
+
inject_interval,
|
308 |
+
inject_background,
|
309 |
+
seed,
|
310 |
+
color_guidance_weight,
|
311 |
+
rich_text_input,
|
312 |
+
],
|
313 |
+
outputs=[
|
314 |
+
plaintext_result,
|
315 |
+
richtext_result,
|
316 |
+
segments,
|
317 |
+
token_map,
|
318 |
+
],
|
319 |
+
fn=generate,
|
320 |
+
cache_examples=True,
|
321 |
+
examples_per_page=20)
|
322 |
+
with gr.Row():
|
323 |
+
color_examples = [
|
324 |
+
[
|
325 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
326 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
327 |
+
11,
|
328 |
+
0.3,
|
329 |
+
0.3,
|
330 |
+
0.3,
|
331 |
+
6,
|
332 |
+
0.5,
|
333 |
+
None,
|
334 |
+
],
|
335 |
+
[
|
336 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
337 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
338 |
+
11,
|
339 |
+
0.3,
|
340 |
+
0.3,
|
341 |
+
0.3,
|
342 |
+
6,
|
343 |
+
0.5,
|
344 |
+
None,
|
345 |
+
],
|
346 |
+
[
|
347 |
+
'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
|
348 |
+
'',
|
349 |
+
10,
|
350 |
+
0.4,
|
351 |
+
0.5,
|
352 |
+
0.3,
|
353 |
+
6,
|
354 |
+
0.5,
|
355 |
+
None,
|
356 |
+
],
|
357 |
+
[
|
358 |
+
'{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
|
359 |
+
'',
|
360 |
+
3,
|
361 |
+
0.3,
|
362 |
+
0,
|
363 |
+
0,
|
364 |
+
9,
|
365 |
+
1,
|
366 |
+
None,
|
367 |
+
],
|
368 |
+
[
|
369 |
+
'{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
|
370 |
+
'',
|
371 |
+
5,
|
372 |
+
0.4,
|
373 |
+
0.3,
|
374 |
+
0.3,
|
375 |
+
5,
|
376 |
+
0.6,
|
377 |
+
None,
|
378 |
+
],
|
379 |
+
]
|
380 |
+
gr.Examples(examples=color_examples,
|
381 |
+
label='Font color examples',
|
382 |
+
inputs=[
|
383 |
+
text_input,
|
384 |
+
negative_prompt,
|
385 |
+
num_segments,
|
386 |
+
segment_threshold,
|
387 |
+
inject_interval,
|
388 |
+
inject_background,
|
389 |
+
seed,
|
390 |
+
color_guidance_weight,
|
391 |
+
rich_text_input,
|
392 |
+
],
|
393 |
+
outputs=[
|
394 |
+
plaintext_result,
|
395 |
+
richtext_result,
|
396 |
+
segments,
|
397 |
+
token_map,
|
398 |
+
],
|
399 |
+
fn=generate,
|
400 |
+
cache_examples=True,
|
401 |
+
examples_per_page=20)
|
402 |
+
|
403 |
+
with gr.Row():
|
404 |
+
style_examples = [
|
405 |
+
[
|
406 |
+
'{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
|
407 |
+
'',
|
408 |
+
10,
|
409 |
+
0.4,
|
410 |
+
0,
|
411 |
+
0.2,
|
412 |
+
3,
|
413 |
+
0,
|
414 |
+
None,
|
415 |
+
],
|
416 |
+
[
|
417 |
+
'{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
|
418 |
+
'worst quality, dark, poor quality',
|
419 |
+
5,
|
420 |
+
0.3,
|
421 |
+
0,
|
422 |
+
0,
|
423 |
+
9,
|
424 |
+
0.5,
|
425 |
+
None,
|
426 |
+
],
|
427 |
+
[
|
428 |
+
'{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
|
429 |
+
'',
|
430 |
+
2,
|
431 |
+
0.35,
|
432 |
+
0,
|
433 |
+
0,
|
434 |
+
6,
|
435 |
+
0.5,
|
436 |
+
None,
|
437 |
+
],
|
438 |
+
]
|
439 |
+
gr.Examples(examples=style_examples,
|
440 |
+
label='Font style examples',
|
441 |
+
inputs=[
|
442 |
+
text_input,
|
443 |
+
negative_prompt,
|
444 |
+
num_segments,
|
445 |
+
segment_threshold,
|
446 |
+
inject_interval,
|
447 |
+
inject_background,
|
448 |
+
seed,
|
449 |
+
color_guidance_weight,
|
450 |
+
rich_text_input,
|
451 |
+
],
|
452 |
+
outputs=[
|
453 |
+
plaintext_result,
|
454 |
+
richtext_result,
|
455 |
+
segments,
|
456 |
+
token_map,
|
457 |
+
],
|
458 |
+
fn=generate,
|
459 |
+
cache_examples=True,
|
460 |
+
examples_per_page=20)
|
461 |
+
|
462 |
+
with gr.Row():
|
463 |
+
size_examples = [
|
464 |
+
[
|
465 |
+
'{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
|
466 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
467 |
+
5,
|
468 |
+
0.3,
|
469 |
+
0,
|
470 |
+
0,
|
471 |
+
13,
|
472 |
+
1,
|
473 |
+
None,
|
474 |
+
],
|
475 |
+
[
|
476 |
+
'{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
|
477 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
478 |
+
5,
|
479 |
+
0.3,
|
480 |
+
0,
|
481 |
+
0,
|
482 |
+
13,
|
483 |
+
1,
|
484 |
+
None,
|
485 |
+
],
|
486 |
+
[
|
487 |
+
'{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
|
488 |
+
'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
|
489 |
+
5,
|
490 |
+
0.3,
|
491 |
+
0,
|
492 |
+
0,
|
493 |
+
13,
|
494 |
+
1,
|
495 |
+
None,
|
496 |
+
],
|
497 |
+
]
|
498 |
+
gr.Examples(examples=size_examples,
|
499 |
+
label='Font size examples',
|
500 |
+
inputs=[
|
501 |
+
text_input,
|
502 |
+
negative_prompt,
|
503 |
+
num_segments,
|
504 |
+
segment_threshold,
|
505 |
+
inject_interval,
|
506 |
+
inject_background,
|
507 |
+
seed,
|
508 |
+
color_guidance_weight,
|
509 |
+
rich_text_input,
|
510 |
+
],
|
511 |
+
outputs=[
|
512 |
+
plaintext_result,
|
513 |
+
richtext_result,
|
514 |
+
segments,
|
515 |
+
token_map,
|
516 |
+
],
|
517 |
+
fn=generate,
|
518 |
+
cache_examples=True,
|
519 |
+
examples_per_page=20)
|
520 |
+
generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
|
521 |
+
fn=generate,
|
522 |
+
inputs=[
|
523 |
+
text_input,
|
524 |
+
negative_prompt,
|
525 |
+
num_segments,
|
526 |
+
segment_threshold,
|
527 |
+
inject_interval,
|
528 |
+
inject_background,
|
529 |
+
seed,
|
530 |
+
color_guidance_weight,
|
531 |
+
rich_text_input,
|
532 |
+
height,
|
533 |
+
width,
|
534 |
+
steps,
|
535 |
+
guidance_weight,
|
536 |
+
],
|
537 |
+
outputs=[plaintext_result, richtext_result, segments, token_map],
|
538 |
+
_js=get_js_data
|
539 |
+
).then(
|
540 |
+
fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
|
541 |
+
text_input.change(
|
542 |
+
fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
|
543 |
+
# load url param prompt to textinput
|
544 |
+
load_params_button.click(fn=lambda x: x['prompt'], inputs=[
|
545 |
+
url_params], outputs=[text_input], queue=False)
|
546 |
+
demo.load(
|
547 |
+
fn=load_url_params,
|
548 |
+
inputs=[url_params],
|
549 |
+
outputs=[load_params_button, url_params],
|
550 |
+
_js=get_window_url_params
|
551 |
+
)
|
552 |
+
demo.queue(concurrency_count=1)
|
553 |
+
demo.launch(share=False)
|
554 |
+
|
555 |
+
|
556 |
+
if __name__ == "__main__":
|
557 |
+
main()
|
models/attention.py
ADDED
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
import warnings
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
25 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
26 |
+
from diffusers.utils import BaseOutput
|
27 |
+
from diffusers.utils.import_utils import is_xformers_available
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class Transformer2DModelOutput(BaseOutput):
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
35 |
+
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
36 |
+
for the unnoised latent pixels.
|
37 |
+
"""
|
38 |
+
|
39 |
+
sample: torch.FloatTensor
|
40 |
+
|
41 |
+
|
42 |
+
if is_xformers_available():
|
43 |
+
import xformers
|
44 |
+
import xformers.ops
|
45 |
+
else:
|
46 |
+
xformers = None
|
47 |
+
|
48 |
+
|
49 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
50 |
+
"""
|
51 |
+
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
52 |
+
embeddings) inputs.
|
53 |
+
|
54 |
+
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
55 |
+
transformer action. Finally, reshape to image.
|
56 |
+
|
57 |
+
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
58 |
+
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
59 |
+
classes of unnoised image.
|
60 |
+
|
61 |
+
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
62 |
+
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
66 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
67 |
+
in_channels (`int`, *optional*):
|
68 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
69 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
70 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
71 |
+
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
72 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
73 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
74 |
+
`ImagePositionalEmbeddings`.
|
75 |
+
num_vector_embeds (`int`, *optional*):
|
76 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
77 |
+
Includes the class for the masked latent pixel.
|
78 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
79 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
80 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
81 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
82 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
83 |
+
attention_bias (`bool`, *optional*):
|
84 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
85 |
+
"""
|
86 |
+
|
87 |
+
@register_to_config
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
num_attention_heads: int = 16,
|
91 |
+
attention_head_dim: int = 88,
|
92 |
+
in_channels: Optional[int] = None,
|
93 |
+
num_layers: int = 1,
|
94 |
+
dropout: float = 0.0,
|
95 |
+
norm_num_groups: int = 32,
|
96 |
+
cross_attention_dim: Optional[int] = None,
|
97 |
+
attention_bias: bool = False,
|
98 |
+
sample_size: Optional[int] = None,
|
99 |
+
num_vector_embeds: Optional[int] = None,
|
100 |
+
activation_fn: str = "geglu",
|
101 |
+
num_embeds_ada_norm: Optional[int] = None,
|
102 |
+
use_linear_projection: bool = False,
|
103 |
+
only_cross_attention: bool = False,
|
104 |
+
):
|
105 |
+
super().__init__()
|
106 |
+
self.use_linear_projection = use_linear_projection
|
107 |
+
self.num_attention_heads = num_attention_heads
|
108 |
+
self.attention_head_dim = attention_head_dim
|
109 |
+
inner_dim = num_attention_heads * attention_head_dim
|
110 |
+
|
111 |
+
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
112 |
+
# Define whether input is continuous or discrete depending on configuration
|
113 |
+
self.is_input_continuous = in_channels is not None
|
114 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
115 |
+
|
116 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
117 |
+
raise ValueError(
|
118 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
119 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
120 |
+
)
|
121 |
+
elif not self.is_input_continuous and not self.is_input_vectorized:
|
122 |
+
raise ValueError(
|
123 |
+
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
124 |
+
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
125 |
+
)
|
126 |
+
|
127 |
+
# 2. Define input layers
|
128 |
+
if self.is_input_continuous:
|
129 |
+
self.in_channels = in_channels
|
130 |
+
|
131 |
+
self.norm = torch.nn.GroupNorm(
|
132 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
133 |
+
if use_linear_projection:
|
134 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
135 |
+
else:
|
136 |
+
self.proj_in = nn.Conv2d(
|
137 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
138 |
+
elif self.is_input_vectorized:
|
139 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
140 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
141 |
+
|
142 |
+
self.height = sample_size
|
143 |
+
self.width = sample_size
|
144 |
+
self.num_vector_embeds = num_vector_embeds
|
145 |
+
self.num_latent_pixels = self.height * self.width
|
146 |
+
|
147 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
148 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
149 |
+
)
|
150 |
+
|
151 |
+
# 3. Define transformers blocks
|
152 |
+
self.transformer_blocks = nn.ModuleList(
|
153 |
+
[
|
154 |
+
BasicTransformerBlock(
|
155 |
+
inner_dim,
|
156 |
+
num_attention_heads,
|
157 |
+
attention_head_dim,
|
158 |
+
dropout=dropout,
|
159 |
+
cross_attention_dim=cross_attention_dim,
|
160 |
+
activation_fn=activation_fn,
|
161 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
162 |
+
attention_bias=attention_bias,
|
163 |
+
only_cross_attention=only_cross_attention,
|
164 |
+
)
|
165 |
+
for d in range(num_layers)
|
166 |
+
]
|
167 |
+
)
|
168 |
+
|
169 |
+
# 4. Define output layers
|
170 |
+
if self.is_input_continuous:
|
171 |
+
if use_linear_projection:
|
172 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
173 |
+
else:
|
174 |
+
self.proj_out = nn.Conv2d(
|
175 |
+
inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
176 |
+
elif self.is_input_vectorized:
|
177 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
178 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
179 |
+
|
180 |
+
def _set_attention_slice(self, slice_size):
|
181 |
+
for block in self.transformer_blocks:
|
182 |
+
block._set_attention_slice(slice_size)
|
183 |
+
|
184 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
|
185 |
+
text_format_dict={}, return_dict: bool = True):
|
186 |
+
"""
|
187 |
+
Args:
|
188 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
189 |
+
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
190 |
+
hidden_states
|
191 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
192 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
193 |
+
self-attention.
|
194 |
+
timestep ( `torch.long`, *optional*):
|
195 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
196 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
197 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
201 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
202 |
+
tensor.
|
203 |
+
"""
|
204 |
+
# 1. Input
|
205 |
+
if self.is_input_continuous:
|
206 |
+
batch, channel, height, weight = hidden_states.shape
|
207 |
+
residual = hidden_states
|
208 |
+
|
209 |
+
hidden_states = self.norm(hidden_states)
|
210 |
+
if not self.use_linear_projection:
|
211 |
+
hidden_states = self.proj_in(hidden_states)
|
212 |
+
inner_dim = hidden_states.shape[1]
|
213 |
+
hidden_states = hidden_states.permute(
|
214 |
+
0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
215 |
+
else:
|
216 |
+
inner_dim = hidden_states.shape[1]
|
217 |
+
hidden_states = hidden_states.permute(
|
218 |
+
0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
219 |
+
hidden_states = self.proj_in(hidden_states)
|
220 |
+
elif self.is_input_vectorized:
|
221 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
222 |
+
|
223 |
+
# 2. Blocks
|
224 |
+
for block in self.transformer_blocks:
|
225 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep,
|
226 |
+
text_format_dict=text_format_dict)
|
227 |
+
|
228 |
+
# 3. Output
|
229 |
+
if self.is_input_continuous:
|
230 |
+
if not self.use_linear_projection:
|
231 |
+
hidden_states = (
|
232 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(
|
233 |
+
0, 3, 1, 2).contiguous()
|
234 |
+
)
|
235 |
+
hidden_states = self.proj_out(hidden_states)
|
236 |
+
else:
|
237 |
+
hidden_states = self.proj_out(hidden_states)
|
238 |
+
hidden_states = (
|
239 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(
|
240 |
+
0, 3, 1, 2).contiguous()
|
241 |
+
)
|
242 |
+
|
243 |
+
output = hidden_states + residual
|
244 |
+
elif self.is_input_vectorized:
|
245 |
+
hidden_states = self.norm_out(hidden_states)
|
246 |
+
logits = self.out(hidden_states)
|
247 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
248 |
+
logits = logits.permute(0, 2, 1)
|
249 |
+
|
250 |
+
# log(p(x_0))
|
251 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
252 |
+
|
253 |
+
if not return_dict:
|
254 |
+
return (output,)
|
255 |
+
|
256 |
+
return Transformer2DModelOutput(sample=output)
|
257 |
+
|
258 |
+
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
259 |
+
for block in self.transformer_blocks:
|
260 |
+
block._set_use_memory_efficient_attention_xformers(
|
261 |
+
use_memory_efficient_attention_xformers)
|
262 |
+
|
263 |
+
|
264 |
+
class AttentionBlock(nn.Module):
|
265 |
+
"""
|
266 |
+
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
267 |
+
to the N-d case.
|
268 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
269 |
+
Uses three q, k, v linear layers to compute attention.
|
270 |
+
|
271 |
+
Parameters:
|
272 |
+
channels (`int`): The number of channels in the input and output.
|
273 |
+
num_head_channels (`int`, *optional*):
|
274 |
+
The number of channels in each head. If None, then `num_heads` = 1.
|
275 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
276 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
277 |
+
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(
|
281 |
+
self,
|
282 |
+
channels: int,
|
283 |
+
num_head_channels: Optional[int] = None,
|
284 |
+
norm_num_groups: int = 32,
|
285 |
+
rescale_output_factor: float = 1.0,
|
286 |
+
eps: float = 1e-5,
|
287 |
+
):
|
288 |
+
super().__init__()
|
289 |
+
self.channels = channels
|
290 |
+
|
291 |
+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
292 |
+
self.num_head_size = num_head_channels
|
293 |
+
self.group_norm = nn.GroupNorm(
|
294 |
+
num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
295 |
+
|
296 |
+
# define q,k,v as linear layers
|
297 |
+
self.query = nn.Linear(channels, channels)
|
298 |
+
self.key = nn.Linear(channels, channels)
|
299 |
+
self.value = nn.Linear(channels, channels)
|
300 |
+
|
301 |
+
self.rescale_output_factor = rescale_output_factor
|
302 |
+
self.proj_attn = nn.Linear(channels, channels, 1)
|
303 |
+
|
304 |
+
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
305 |
+
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
306 |
+
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
307 |
+
new_projection = projection.view(
|
308 |
+
new_projection_shape).permute(0, 2, 1, 3)
|
309 |
+
return new_projection
|
310 |
+
|
311 |
+
def forward(self, hidden_states):
|
312 |
+
residual = hidden_states
|
313 |
+
batch, channel, height, width = hidden_states.shape
|
314 |
+
|
315 |
+
# norm
|
316 |
+
hidden_states = self.group_norm(hidden_states)
|
317 |
+
|
318 |
+
hidden_states = hidden_states.view(
|
319 |
+
batch, channel, height * width).transpose(1, 2)
|
320 |
+
|
321 |
+
# proj to q, k, v
|
322 |
+
query_proj = self.query(hidden_states)
|
323 |
+
key_proj = self.key(hidden_states)
|
324 |
+
value_proj = self.value(hidden_states)
|
325 |
+
|
326 |
+
scale = 1 / math.sqrt(self.channels / self.num_heads)
|
327 |
+
|
328 |
+
# get scores
|
329 |
+
if self.num_heads > 1:
|
330 |
+
query_states = self.transpose_for_scores(query_proj)
|
331 |
+
key_states = self.transpose_for_scores(key_proj)
|
332 |
+
value_states = self.transpose_for_scores(value_proj)
|
333 |
+
|
334 |
+
# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
|
335 |
+
# or reformulate this into a 3D problem?
|
336 |
+
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
|
337 |
+
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
|
338 |
+
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
|
339 |
+
attention_scores = torch.matmul(
|
340 |
+
query_states, key_states.transpose(-1, -2)) * scale
|
341 |
+
else:
|
342 |
+
query_states, key_states, value_states = query_proj, key_proj, value_proj
|
343 |
+
|
344 |
+
attention_scores = torch.baddbmm(
|
345 |
+
torch.empty(
|
346 |
+
query_states.shape[0],
|
347 |
+
query_states.shape[1],
|
348 |
+
key_states.shape[1],
|
349 |
+
dtype=query_states.dtype,
|
350 |
+
device=query_states.device,
|
351 |
+
),
|
352 |
+
query_states,
|
353 |
+
key_states.transpose(-1, -2),
|
354 |
+
beta=0,
|
355 |
+
alpha=scale,
|
356 |
+
)
|
357 |
+
|
358 |
+
attention_probs = torch.softmax(
|
359 |
+
attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
360 |
+
|
361 |
+
# compute attention output
|
362 |
+
if self.num_heads > 1:
|
363 |
+
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
|
364 |
+
# or reformulate this into a 3D problem?
|
365 |
+
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
|
366 |
+
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
|
367 |
+
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
|
368 |
+
hidden_states = torch.matmul(attention_probs, value_states)
|
369 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
370 |
+
new_hidden_states_shape = hidden_states.size()[
|
371 |
+
:-2] + (self.channels,)
|
372 |
+
hidden_states = hidden_states.view(new_hidden_states_shape)
|
373 |
+
else:
|
374 |
+
hidden_states = torch.bmm(attention_probs, value_states)
|
375 |
+
|
376 |
+
# compute next hidden_states
|
377 |
+
hidden_states = self.proj_attn(hidden_states)
|
378 |
+
hidden_states = hidden_states.transpose(
|
379 |
+
-1, -2).reshape(batch, channel, height, width)
|
380 |
+
|
381 |
+
# res connect and rescale
|
382 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
383 |
+
return hidden_states
|
384 |
+
|
385 |
+
|
386 |
+
class BasicTransformerBlock(nn.Module):
|
387 |
+
r"""
|
388 |
+
A basic Transformer block.
|
389 |
+
|
390 |
+
Parameters:
|
391 |
+
dim (`int`): The number of channels in the input and output.
|
392 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
393 |
+
attention_head_dim (`int`): The number of channels in each head.
|
394 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
395 |
+
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
396 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
397 |
+
num_embeds_ada_norm (:
|
398 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
399 |
+
attention_bias (:
|
400 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
401 |
+
"""
|
402 |
+
|
403 |
+
def __init__(
|
404 |
+
self,
|
405 |
+
dim: int,
|
406 |
+
num_attention_heads: int,
|
407 |
+
attention_head_dim: int,
|
408 |
+
dropout=0.0,
|
409 |
+
cross_attention_dim: Optional[int] = None,
|
410 |
+
activation_fn: str = "geglu",
|
411 |
+
num_embeds_ada_norm: Optional[int] = None,
|
412 |
+
attention_bias: bool = False,
|
413 |
+
only_cross_attention: bool = False,
|
414 |
+
):
|
415 |
+
super().__init__()
|
416 |
+
self.only_cross_attention = only_cross_attention
|
417 |
+
self.attn1 = CrossAttention(
|
418 |
+
query_dim=dim,
|
419 |
+
heads=num_attention_heads,
|
420 |
+
dim_head=attention_head_dim,
|
421 |
+
dropout=dropout,
|
422 |
+
bias=attention_bias,
|
423 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
424 |
+
) # is a self-attention
|
425 |
+
self.ff = FeedForward(dim, dropout=dropout,
|
426 |
+
activation_fn=activation_fn)
|
427 |
+
self.attn2 = CrossAttention(
|
428 |
+
query_dim=dim,
|
429 |
+
cross_attention_dim=cross_attention_dim,
|
430 |
+
heads=num_attention_heads,
|
431 |
+
dim_head=attention_head_dim,
|
432 |
+
dropout=dropout,
|
433 |
+
bias=attention_bias,
|
434 |
+
) # is self-attn if context is none
|
435 |
+
|
436 |
+
# layer norms
|
437 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
438 |
+
if self.use_ada_layer_norm:
|
439 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
440 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
441 |
+
else:
|
442 |
+
self.norm1 = nn.LayerNorm(dim)
|
443 |
+
self.norm2 = nn.LayerNorm(dim)
|
444 |
+
self.norm3 = nn.LayerNorm(dim)
|
445 |
+
|
446 |
+
# if xformers is installed try to use memory_efficient_attention by default
|
447 |
+
if is_xformers_available():
|
448 |
+
try:
|
449 |
+
self._set_use_memory_efficient_attention_xformers(True)
|
450 |
+
except Exception as e:
|
451 |
+
warnings.warn(
|
452 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
453 |
+
f" correctly and a GPU is available: {e}"
|
454 |
+
)
|
455 |
+
|
456 |
+
def _set_attention_slice(self, slice_size):
|
457 |
+
self.attn1._slice_size = slice_size
|
458 |
+
self.attn2._slice_size = slice_size
|
459 |
+
|
460 |
+
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
461 |
+
if not is_xformers_available():
|
462 |
+
print("Here is how to install it")
|
463 |
+
raise ModuleNotFoundError(
|
464 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
465 |
+
" xformers",
|
466 |
+
name="xformers",
|
467 |
+
)
|
468 |
+
elif not torch.cuda.is_available():
|
469 |
+
raise ValueError(
|
470 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
471 |
+
" available for GPU "
|
472 |
+
)
|
473 |
+
else:
|
474 |
+
try:
|
475 |
+
# Make sure we can run the memory efficient attention
|
476 |
+
_ = xformers.ops.memory_efficient_attention(
|
477 |
+
torch.randn((1, 2, 40), device="cuda"),
|
478 |
+
torch.randn((1, 2, 40), device="cuda"),
|
479 |
+
torch.randn((1, 2, 40), device="cuda"),
|
480 |
+
)
|
481 |
+
except Exception as e:
|
482 |
+
raise e
|
483 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
484 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
485 |
+
|
486 |
+
def forward(self, hidden_states, context=None, timestep=None, text_format_dict={}):
|
487 |
+
# 1. Self-Attention
|
488 |
+
norm_hidden_states = (
|
489 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(
|
490 |
+
hidden_states)
|
491 |
+
)
|
492 |
+
|
493 |
+
if self.only_cross_attention:
|
494 |
+
attn_out, _ = self.attn1(
|
495 |
+
norm_hidden_states, context=context, text_format_dict=text_format_dict) + hidden_states
|
496 |
+
hidden_states = attn_out + hidden_states
|
497 |
+
else:
|
498 |
+
attn_out, _ = self.attn1(norm_hidden_states)
|
499 |
+
hidden_states = attn_out + hidden_states
|
500 |
+
|
501 |
+
# 2. Cross-Attention
|
502 |
+
norm_hidden_states = (
|
503 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
|
504 |
+
hidden_states)
|
505 |
+
)
|
506 |
+
attn_out, _ = self.attn2(
|
507 |
+
norm_hidden_states, context=context, text_format_dict=text_format_dict)
|
508 |
+
hidden_states = attn_out + hidden_states
|
509 |
+
|
510 |
+
# 3. Feed-forward
|
511 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
512 |
+
|
513 |
+
return hidden_states
|
514 |
+
|
515 |
+
|
516 |
+
class CrossAttention(nn.Module):
|
517 |
+
r"""
|
518 |
+
A cross attention layer.
|
519 |
+
|
520 |
+
Parameters:
|
521 |
+
query_dim (`int`): The number of channels in the query.
|
522 |
+
cross_attention_dim (`int`, *optional*):
|
523 |
+
The number of channels in the context. If not given, defaults to `query_dim`.
|
524 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
525 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
526 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
527 |
+
bias (`bool`, *optional*, defaults to False):
|
528 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
529 |
+
"""
|
530 |
+
|
531 |
+
def __init__(
|
532 |
+
self,
|
533 |
+
query_dim: int,
|
534 |
+
cross_attention_dim: Optional[int] = None,
|
535 |
+
heads: int = 8,
|
536 |
+
dim_head: int = 64,
|
537 |
+
dropout: float = 0.0,
|
538 |
+
bias=False,
|
539 |
+
):
|
540 |
+
super().__init__()
|
541 |
+
inner_dim = dim_head * heads
|
542 |
+
self.is_cross_attn = cross_attention_dim is not None
|
543 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
544 |
+
|
545 |
+
self.scale = dim_head**-0.5
|
546 |
+
self.heads = heads
|
547 |
+
# for slice_size > 0 the attention score computation
|
548 |
+
# is split across the batch axis to save memory
|
549 |
+
# You can set slice_size with `set_attention_slice`
|
550 |
+
self._slice_size = None
|
551 |
+
self._use_memory_efficient_attention_xformers = False
|
552 |
+
|
553 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
554 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
555 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
556 |
+
|
557 |
+
self.to_out = nn.ModuleList([])
|
558 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
559 |
+
self.to_out.append(nn.Dropout(dropout))
|
560 |
+
|
561 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
562 |
+
batch_size, seq_len, dim = tensor.shape
|
563 |
+
head_size = self.heads
|
564 |
+
tensor = tensor.reshape(batch_size, seq_len,
|
565 |
+
head_size, dim // head_size)
|
566 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
567 |
+
batch_size * head_size, seq_len, dim // head_size)
|
568 |
+
return tensor
|
569 |
+
|
570 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
571 |
+
batch_size, seq_len, dim = tensor.shape
|
572 |
+
head_size = self.heads
|
573 |
+
tensor = tensor.reshape(batch_size // head_size,
|
574 |
+
head_size, seq_len, dim)
|
575 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
576 |
+
batch_size // head_size, seq_len, dim * head_size)
|
577 |
+
return tensor
|
578 |
+
|
579 |
+
def reshape_batch_dim_to_heads_and_average(self, tensor):
|
580 |
+
batch_size, seq_len, seq_len2 = tensor.shape
|
581 |
+
head_size = self.heads
|
582 |
+
tensor = tensor.reshape(batch_size // head_size,
|
583 |
+
head_size, seq_len, seq_len2)
|
584 |
+
return tensor.mean(1)
|
585 |
+
|
586 |
+
def forward(self, hidden_states, real_attn_probs=None, context=None, mask=None, text_format_dict={}):
|
587 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
588 |
+
|
589 |
+
query = self.to_q(hidden_states)
|
590 |
+
context = context if context is not None else hidden_states
|
591 |
+
key = self.to_k(context)
|
592 |
+
value = self.to_v(context)
|
593 |
+
|
594 |
+
dim = query.shape[-1]
|
595 |
+
|
596 |
+
query = self.reshape_heads_to_batch_dim(query)
|
597 |
+
key = self.reshape_heads_to_batch_dim(key)
|
598 |
+
value = self.reshape_heads_to_batch_dim(value)
|
599 |
+
|
600 |
+
# attention, what we cannot get enough of
|
601 |
+
if self._use_memory_efficient_attention_xformers:
|
602 |
+
hidden_states = self._memory_efficient_attention_xformers(
|
603 |
+
query, key, value)
|
604 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
605 |
+
hidden_states = hidden_states.to(query.dtype)
|
606 |
+
else:
|
607 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
608 |
+
# only this attention function is used
|
609 |
+
hidden_states, attn_probs = self._attention(
|
610 |
+
query, key, value, real_attn_probs, **text_format_dict)
|
611 |
+
|
612 |
+
# linear proj
|
613 |
+
hidden_states = self.to_out[0](hidden_states)
|
614 |
+
# dropout
|
615 |
+
hidden_states = self.to_out[1](hidden_states)
|
616 |
+
return hidden_states, attn_probs
|
617 |
+
|
618 |
+
def _qk(self, query, key):
|
619 |
+
return torch.baddbmm(
|
620 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1],
|
621 |
+
dtype=query.dtype, device=query.device),
|
622 |
+
query,
|
623 |
+
key.transpose(-1, -2),
|
624 |
+
beta=0,
|
625 |
+
alpha=self.scale,
|
626 |
+
)
|
627 |
+
|
628 |
+
def _attention(self, query, key, value, real_attn_probs=None, word_pos=None, font_size=None,
|
629 |
+
**kwargs):
|
630 |
+
attention_scores = self._qk(query, key)
|
631 |
+
|
632 |
+
# Font size V2:
|
633 |
+
if self.is_cross_attn and word_pos is not None and font_size is not None:
|
634 |
+
assert key.shape[1] == 77
|
635 |
+
attention_score_exp = attention_scores.exp()
|
636 |
+
font_size_abs, font_size_sign = font_size.abs(), font_size.sign()
|
637 |
+
attention_score_exp[:, :, word_pos] = attention_score_exp[:, :, word_pos].clone(
|
638 |
+
)*font_size_abs
|
639 |
+
attention_probs = attention_score_exp / \
|
640 |
+
attention_score_exp.sum(-1, True)
|
641 |
+
attention_probs[:, :, word_pos] *= font_size_sign
|
642 |
+
else:
|
643 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
644 |
+
|
645 |
+
# compute attention output
|
646 |
+
if real_attn_probs is None:
|
647 |
+
hidden_states = torch.bmm(attention_probs, value)
|
648 |
+
else:
|
649 |
+
if isinstance(real_attn_probs, dict):
|
650 |
+
for pos1, pos2 in zip(real_attn_probs['inject_pos'][0], real_attn_probs['inject_pos'][1]):
|
651 |
+
attention_probs[:, :,
|
652 |
+
pos2] = real_attn_probs['reference'][:, :, pos1]
|
653 |
+
hidden_states = torch.bmm(attention_probs, value)
|
654 |
+
else:
|
655 |
+
hidden_states = torch.bmm(real_attn_probs, value)
|
656 |
+
|
657 |
+
# reshape hidden_states
|
658 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
659 |
+
|
660 |
+
# we also return the map averaged over heads to save memory footprint
|
661 |
+
attention_probs_avg = self.reshape_batch_dim_to_heads_and_average(
|
662 |
+
attention_probs)
|
663 |
+
return hidden_states, [attention_probs_avg, attention_probs]
|
664 |
+
|
665 |
+
def _memory_efficient_attention_xformers(self, query, key, value):
|
666 |
+
query = query.contiguous()
|
667 |
+
key = key.contiguous()
|
668 |
+
value = value.contiguous()
|
669 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
670 |
+
query, key, value, attn_bias=None)
|
671 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
672 |
+
return hidden_states
|
673 |
+
|
674 |
+
|
675 |
+
class FeedForward(nn.Module):
|
676 |
+
r"""
|
677 |
+
A feed-forward layer.
|
678 |
+
|
679 |
+
Parameters:
|
680 |
+
dim (`int`): The number of channels in the input.
|
681 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
682 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
683 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
684 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
685 |
+
"""
|
686 |
+
|
687 |
+
def __init__(
|
688 |
+
self,
|
689 |
+
dim: int,
|
690 |
+
dim_out: Optional[int] = None,
|
691 |
+
mult: int = 4,
|
692 |
+
dropout: float = 0.0,
|
693 |
+
activation_fn: str = "geglu",
|
694 |
+
):
|
695 |
+
super().__init__()
|
696 |
+
inner_dim = int(dim * mult)
|
697 |
+
dim_out = dim_out if dim_out is not None else dim
|
698 |
+
|
699 |
+
if activation_fn == "geglu":
|
700 |
+
geglu = GEGLU(dim, inner_dim)
|
701 |
+
elif activation_fn == "geglu-approximate":
|
702 |
+
geglu = ApproximateGELU(dim, inner_dim)
|
703 |
+
|
704 |
+
self.net = nn.ModuleList([])
|
705 |
+
# project in
|
706 |
+
self.net.append(geglu)
|
707 |
+
# project dropout
|
708 |
+
self.net.append(nn.Dropout(dropout))
|
709 |
+
# project out
|
710 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
711 |
+
|
712 |
+
def forward(self, hidden_states):
|
713 |
+
for module in self.net:
|
714 |
+
hidden_states = module(hidden_states)
|
715 |
+
return hidden_states
|
716 |
+
|
717 |
+
|
718 |
+
# feedforward
|
719 |
+
class GEGLU(nn.Module):
|
720 |
+
r"""
|
721 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
722 |
+
|
723 |
+
Parameters:
|
724 |
+
dim_in (`int`): The number of channels in the input.
|
725 |
+
dim_out (`int`): The number of channels in the output.
|
726 |
+
"""
|
727 |
+
|
728 |
+
def __init__(self, dim_in: int, dim_out: int):
|
729 |
+
super().__init__()
|
730 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
731 |
+
|
732 |
+
def gelu(self, gate):
|
733 |
+
if gate.device.type != "mps":
|
734 |
+
return F.gelu(gate)
|
735 |
+
# mps: gelu is not implemented for float16
|
736 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
737 |
+
|
738 |
+
def forward(self, hidden_states):
|
739 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
740 |
+
return hidden_states * self.gelu(gate)
|
741 |
+
|
742 |
+
|
743 |
+
class ApproximateGELU(nn.Module):
|
744 |
+
"""
|
745 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
746 |
+
|
747 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
748 |
+
"""
|
749 |
+
|
750 |
+
def __init__(self, dim_in: int, dim_out: int):
|
751 |
+
super().__init__()
|
752 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
753 |
+
|
754 |
+
def forward(self, x):
|
755 |
+
x = self.proj(x)
|
756 |
+
return x * torch.sigmoid(1.702 * x)
|
757 |
+
|
758 |
+
|
759 |
+
class AdaLayerNorm(nn.Module):
|
760 |
+
"""
|
761 |
+
Norm layer modified to incorporate timestep embeddings.
|
762 |
+
"""
|
763 |
+
|
764 |
+
def __init__(self, embedding_dim, num_embeddings):
|
765 |
+
super().__init__()
|
766 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
767 |
+
self.silu = nn.SiLU()
|
768 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
769 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
770 |
+
|
771 |
+
def forward(self, x, timestep):
|
772 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
773 |
+
scale, shift = torch.chunk(emb, 2)
|
774 |
+
x = self.norm(x) * (1 + scale) + shift
|
775 |
+
return x
|
776 |
+
|
777 |
+
|
778 |
+
class DualTransformer2DModel(nn.Module):
|
779 |
+
"""
|
780 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
781 |
+
|
782 |
+
Parameters:
|
783 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
784 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
785 |
+
in_channels (`int`, *optional*):
|
786 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
787 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
788 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
789 |
+
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
790 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
791 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
792 |
+
`ImagePositionalEmbeddings`.
|
793 |
+
num_vector_embeds (`int`, *optional*):
|
794 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
795 |
+
Includes the class for the masked latent pixel.
|
796 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
797 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
798 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
799 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
800 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
801 |
+
attention_bias (`bool`, *optional*):
|
802 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
803 |
+
"""
|
804 |
+
|
805 |
+
def __init__(
|
806 |
+
self,
|
807 |
+
num_attention_heads: int = 16,
|
808 |
+
attention_head_dim: int = 88,
|
809 |
+
in_channels: Optional[int] = None,
|
810 |
+
num_layers: int = 1,
|
811 |
+
dropout: float = 0.0,
|
812 |
+
norm_num_groups: int = 32,
|
813 |
+
cross_attention_dim: Optional[int] = None,
|
814 |
+
attention_bias: bool = False,
|
815 |
+
sample_size: Optional[int] = None,
|
816 |
+
num_vector_embeds: Optional[int] = None,
|
817 |
+
activation_fn: str = "geglu",
|
818 |
+
num_embeds_ada_norm: Optional[int] = None,
|
819 |
+
):
|
820 |
+
super().__init__()
|
821 |
+
self.transformers = nn.ModuleList(
|
822 |
+
[
|
823 |
+
Transformer2DModel(
|
824 |
+
num_attention_heads=num_attention_heads,
|
825 |
+
attention_head_dim=attention_head_dim,
|
826 |
+
in_channels=in_channels,
|
827 |
+
num_layers=num_layers,
|
828 |
+
dropout=dropout,
|
829 |
+
norm_num_groups=norm_num_groups,
|
830 |
+
cross_attention_dim=cross_attention_dim,
|
831 |
+
attention_bias=attention_bias,
|
832 |
+
sample_size=sample_size,
|
833 |
+
num_vector_embeds=num_vector_embeds,
|
834 |
+
activation_fn=activation_fn,
|
835 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
836 |
+
)
|
837 |
+
for _ in range(2)
|
838 |
+
]
|
839 |
+
)
|
840 |
+
|
841 |
+
# Variables that can be set by a pipeline:
|
842 |
+
|
843 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
844 |
+
self.mix_ratio = 0.5
|
845 |
+
|
846 |
+
# The shape of `encoder_hidden_states` is expected to be
|
847 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
848 |
+
self.condition_lengths = [77, 257]
|
849 |
+
|
850 |
+
# Which transformer to use to encode which condition.
|
851 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
852 |
+
self.transformer_index_for_condition = [1, 0]
|
853 |
+
|
854 |
+
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
|
855 |
+
"""
|
856 |
+
Args:
|
857 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
858 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
859 |
+
hidden_states
|
860 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
861 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
862 |
+
self-attention.
|
863 |
+
timestep ( `torch.long`, *optional*):
|
864 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
865 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
866 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
867 |
+
|
868 |
+
Returns:
|
869 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
870 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
871 |
+
tensor.
|
872 |
+
"""
|
873 |
+
input_states = hidden_states
|
874 |
+
|
875 |
+
encoded_states = []
|
876 |
+
tokens_start = 0
|
877 |
+
for i in range(2):
|
878 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
879 |
+
condition_state = encoder_hidden_states[:,
|
880 |
+
tokens_start: tokens_start + self.condition_lengths[i]]
|
881 |
+
transformer_index = self.transformer_index_for_condition[i]
|
882 |
+
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
|
883 |
+
0
|
884 |
+
]
|
885 |
+
encoded_states.append(encoded_state - input_states)
|
886 |
+
tokens_start += self.condition_lengths[i]
|
887 |
+
|
888 |
+
output_states = encoded_states[0] * self.mix_ratio + \
|
889 |
+
encoded_states[1] * (1 - self.mix_ratio)
|
890 |
+
output_states = output_states + input_states
|
891 |
+
|
892 |
+
if not return_dict:
|
893 |
+
return (output_states,)
|
894 |
+
|
895 |
+
return Transformer2DModelOutput(sample=output_states)
|
896 |
+
|
897 |
+
def _set_attention_slice(self, slice_size):
|
898 |
+
for transformer in self.transformers:
|
899 |
+
transformer._set_attention_slice(slice_size)
|
900 |
+
|
901 |
+
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
902 |
+
for transformer in self.transformers:
|
903 |
+
transformer._set_use_memory_efficient_attention_xformers(
|
904 |
+
use_memory_efficient_attention_xformers)
|
models/region_diffusion.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import collections
|
4 |
+
import torch.nn as nn
|
5 |
+
from functools import partial
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
7 |
+
from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
|
8 |
+
from models.unet_2d_condition import UNet2DConditionModel
|
9 |
+
from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers
|
10 |
+
|
11 |
+
# suppress partial model loading warning
|
12 |
+
logging.set_verbosity_error()
|
13 |
+
|
14 |
+
|
15 |
+
class RegionDiffusion(nn.Module):
|
16 |
+
def __init__(self, device):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.device = device
|
20 |
+
self.num_train_timesteps = 1000
|
21 |
+
self.clip_gradient = False
|
22 |
+
|
23 |
+
print(f'[INFO] loading stable diffusion...')
|
24 |
+
model_id = 'runwayml/stable-diffusion-v1-5'
|
25 |
+
|
26 |
+
self.vae = AutoencoderKL.from_pretrained(
|
27 |
+
model_id, subfolder="vae").to(self.device)
|
28 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
29 |
+
model_id, subfolder='tokenizer')
|
30 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
31 |
+
model_id, subfolder='text_encoder').to(self.device)
|
32 |
+
self.unet = UNet2DConditionModel.from_pretrained(
|
33 |
+
model_id, subfolder="unet").to(self.device)
|
34 |
+
|
35 |
+
self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
36 |
+
num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True, steps_offset=1)
|
37 |
+
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
38 |
+
|
39 |
+
self.masks = []
|
40 |
+
self.attention_maps = None
|
41 |
+
self.selfattn_maps = None
|
42 |
+
self.crossattn_maps = None
|
43 |
+
self.color_loss = torch.nn.functional.mse_loss
|
44 |
+
self.forward_hooks = []
|
45 |
+
self.forward_replacement_hooks = []
|
46 |
+
|
47 |
+
print(f'[INFO] loaded stable diffusion!')
|
48 |
+
|
49 |
+
def get_text_embeds(self, prompt, negative_prompt):
|
50 |
+
# prompt, negative_prompt: [str]
|
51 |
+
|
52 |
+
# Tokenize text and get embeddings
|
53 |
+
text_input = self.tokenizer(
|
54 |
+
prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
text_embeddings = self.text_encoder(
|
58 |
+
text_input.input_ids.to(self.device))[0]
|
59 |
+
|
60 |
+
# Do the same for unconditional embeddings
|
61 |
+
uncond_input = self.tokenizer(negative_prompt, padding='max_length',
|
62 |
+
max_length=self.tokenizer.model_max_length, return_tensors='pt')
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
uncond_embeddings = self.text_encoder(
|
66 |
+
uncond_input.input_ids.to(self.device))[0]
|
67 |
+
|
68 |
+
# Cat for final embeddings
|
69 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
70 |
+
return text_embeddings
|
71 |
+
|
72 |
+
def get_text_embeds_list(self, prompts):
|
73 |
+
# prompts: [list]
|
74 |
+
text_embeddings = []
|
75 |
+
for prompt in prompts:
|
76 |
+
# Tokenize text and get embeddings
|
77 |
+
text_input = self.tokenizer(
|
78 |
+
[prompt], padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
text_embeddings.append(self.text_encoder(
|
82 |
+
text_input.input_ids.to(self.device))[0])
|
83 |
+
|
84 |
+
return text_embeddings
|
85 |
+
|
86 |
+
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
|
87 |
+
latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, inject_background=0):
|
88 |
+
|
89 |
+
if latents is None:
|
90 |
+
latents = torch.randn(
|
91 |
+
(1, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
92 |
+
|
93 |
+
if inject_selfattn > 0 or inject_background > 0:
|
94 |
+
latents_reference = latents.clone().detach()
|
95 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
96 |
+
n_styles = text_embeddings.shape[0]-1
|
97 |
+
assert n_styles == len(self.masks)
|
98 |
+
with torch.autocast('cuda'):
|
99 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
100 |
+
|
101 |
+
# predict the noise residual
|
102 |
+
with torch.no_grad():
|
103 |
+
# tokens without any attributes
|
104 |
+
feat_inject_step = t > (1-inject_selfattn) * 1000
|
105 |
+
background_inject_step = i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0
|
106 |
+
noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
|
107 |
+
text_format_dict={})['sample']
|
108 |
+
noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
|
109 |
+
text_format_dict=text_format_dict)['sample']
|
110 |
+
if inject_selfattn > 0 or inject_background > 0:
|
111 |
+
noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
|
112 |
+
text_format_dict={})['sample']
|
113 |
+
self.register_selfattn_hooks(feat_inject_step)
|
114 |
+
noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
|
115 |
+
text_format_dict={})['sample']
|
116 |
+
self.remove_selfattn_hooks()
|
117 |
+
noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
|
118 |
+
noise_pred_text = noise_pred_text_cur * self.masks[-1]
|
119 |
+
# tokens with attributes
|
120 |
+
for style_i, mask in enumerate(self.masks[:-1]):
|
121 |
+
self.register_replacement_hooks(feat_inject_step)
|
122 |
+
noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
|
123 |
+
text_format_dict={})['sample']
|
124 |
+
self.remove_replacement_hooks()
|
125 |
+
noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
|
126 |
+
noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
|
127 |
+
|
128 |
+
# perform classifier-free guidance
|
129 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
130 |
+
(noise_pred_text - noise_pred_uncond)
|
131 |
+
|
132 |
+
if inject_selfattn > 0 or inject_background > 0:
|
133 |
+
noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
|
134 |
+
(noise_pred_text_refer - noise_pred_uncond_refer)
|
135 |
+
|
136 |
+
# compute the previous noisy sample x_t -> x_t-1
|
137 |
+
latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
|
138 |
+
torch.cat([latents, latents_reference]))[
|
139 |
+
'prev_sample']
|
140 |
+
latents, latents_reference = torch.chunk(
|
141 |
+
latents_reference, 2, dim=0)
|
142 |
+
|
143 |
+
else:
|
144 |
+
# compute the previous noisy sample x_t -> x_t-1
|
145 |
+
latents = self.scheduler.step(noise_pred, t, latents)[
|
146 |
+
'prev_sample']
|
147 |
+
|
148 |
+
# apply guidance
|
149 |
+
if use_guidance and t < text_format_dict['guidance_start_step']:
|
150 |
+
with torch.enable_grad():
|
151 |
+
if not latents.requires_grad:
|
152 |
+
latents.requires_grad = True
|
153 |
+
latents_0 = self.predict_x0(latents, noise_pred, t)
|
154 |
+
latents_inp = 1 / 0.18215 * latents_0
|
155 |
+
imgs = self.vae.decode(latents_inp).sample
|
156 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
157 |
+
loss_total = 0.
|
158 |
+
for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
|
159 |
+
avg_rgb = (
|
160 |
+
imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
|
161 |
+
loss = self.color_loss(
|
162 |
+
avg_rgb, rgb_val[:, :, 0, 0])*100
|
163 |
+
loss_total += loss
|
164 |
+
loss_total.backward()
|
165 |
+
latents = (
|
166 |
+
latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone()
|
167 |
+
|
168 |
+
# apply background injection
|
169 |
+
if background_inject_step:
|
170 |
+
latents = latents_reference * self.masks[-1] + latents * \
|
171 |
+
(1-self.masks[-1])
|
172 |
+
return latents
|
173 |
+
|
174 |
+
def predict_x0(self, x_t, eps_t, t):
|
175 |
+
alpha_t = self.scheduler.alphas_cumprod[t]
|
176 |
+
return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
|
177 |
+
|
178 |
+
def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
|
179 |
+
guidance_scale=7.5, latents=None):
|
180 |
+
|
181 |
+
if isinstance(prompts, str):
|
182 |
+
prompts = [prompts]
|
183 |
+
|
184 |
+
if isinstance(negative_prompts, str):
|
185 |
+
negative_prompts = [negative_prompts]
|
186 |
+
|
187 |
+
# Prompts -> text embeds
|
188 |
+
text_embeddings = self.get_text_embeds(
|
189 |
+
prompts, negative_prompts) # [2, 77, 768]
|
190 |
+
if latents is None:
|
191 |
+
latents = torch.randn(
|
192 |
+
(text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
193 |
+
|
194 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
195 |
+
self.remove_replacement_hooks()
|
196 |
+
|
197 |
+
with torch.autocast('cuda'):
|
198 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
199 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
200 |
+
latent_model_input = torch.cat([latents] * 2)
|
201 |
+
|
202 |
+
# predict the noise residual
|
203 |
+
with torch.no_grad():
|
204 |
+
noise_pred = self.unet(
|
205 |
+
latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
|
206 |
+
|
207 |
+
# perform guidance
|
208 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
209 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
210 |
+
(noise_pred_text - noise_pred_uncond)
|
211 |
+
|
212 |
+
# compute the previous noisy sample x_t -> x_t-1
|
213 |
+
latents = self.scheduler.step(noise_pred, t, latents)[
|
214 |
+
'prev_sample']
|
215 |
+
|
216 |
+
# Img latents -> imgs
|
217 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
218 |
+
|
219 |
+
# Img to Numpy
|
220 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
221 |
+
imgs = (imgs * 255).round().astype('uint8')
|
222 |
+
|
223 |
+
return imgs
|
224 |
+
|
225 |
+
def decode_latents(self, latents):
|
226 |
+
|
227 |
+
latents = 1 / 0.18215 * latents
|
228 |
+
|
229 |
+
with torch.no_grad():
|
230 |
+
imgs = self.vae.decode(latents).sample
|
231 |
+
|
232 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
233 |
+
|
234 |
+
return imgs
|
235 |
+
|
236 |
+
def encode_imgs(self, imgs):
|
237 |
+
# imgs: [B, 3, H, W]
|
238 |
+
|
239 |
+
imgs = 2 * imgs - 1
|
240 |
+
|
241 |
+
posterior = self.vae.encode(imgs).latent_dist
|
242 |
+
latents = posterior.sample() * 0.18215
|
243 |
+
|
244 |
+
return latents
|
245 |
+
|
246 |
+
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
|
247 |
+
guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, inject_background=0):
|
248 |
+
|
249 |
+
if isinstance(prompts, str):
|
250 |
+
prompts = [prompts]
|
251 |
+
|
252 |
+
if isinstance(negative_prompts, str):
|
253 |
+
negative_prompts = [negative_prompts]
|
254 |
+
|
255 |
+
# Prompts -> text embeds
|
256 |
+
text_embeds = self.get_text_embeds(
|
257 |
+
prompts, negative_prompts) # [2, 77, 768]
|
258 |
+
|
259 |
+
# else:
|
260 |
+
latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
|
261 |
+
num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
|
262 |
+
use_guidance=use_guidance, text_format_dict=text_format_dict,
|
263 |
+
inject_selfattn=inject_selfattn, inject_background=inject_background) # [1, 4, 64, 64]
|
264 |
+
# Img latents -> imgs
|
265 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
266 |
+
|
267 |
+
# Img to Numpy
|
268 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
269 |
+
imgs = (imgs * 255).round().astype('uint8')
|
270 |
+
|
271 |
+
return imgs
|
272 |
+
|
273 |
+
def reset_attention_maps(self):
|
274 |
+
r"""Function to reset attention maps.
|
275 |
+
We reset attention maps because we append them while getting hooks
|
276 |
+
to visualize attention maps for every step.
|
277 |
+
"""
|
278 |
+
for key in self.selfattn_maps:
|
279 |
+
self.selfattn_maps[key] = []
|
280 |
+
for key in self.crossattn_maps:
|
281 |
+
self.crossattn_maps[key] = []
|
282 |
+
|
283 |
+
def register_evaluation_hooks(self):
|
284 |
+
r"""Function for registering hooks during evaluation.
|
285 |
+
We mainly store activation maps averaged over queries.
|
286 |
+
"""
|
287 |
+
self.forward_hooks = []
|
288 |
+
|
289 |
+
def save_activations(activations, name, module, inp, out):
|
290 |
+
r"""
|
291 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
292 |
+
"""
|
293 |
+
# out[0] - final output of attention layer
|
294 |
+
# out[1] - attention probability matrix
|
295 |
+
if 'attn2' in name:
|
296 |
+
assert out[1].shape[-1] == 77
|
297 |
+
activations[name].append(out[1].detach().cpu())
|
298 |
+
else:
|
299 |
+
assert out[1].shape[-1] != 77
|
300 |
+
attention_dict = collections.defaultdict(list)
|
301 |
+
for name, module in self.unet.named_modules():
|
302 |
+
leaf_name = name.split('.')[-1]
|
303 |
+
if 'attn' in leaf_name:
|
304 |
+
# Register hook to obtain outputs at every attention layer.
|
305 |
+
self.forward_hooks.append(module.register_forward_hook(
|
306 |
+
partial(save_activations, attention_dict, name)
|
307 |
+
))
|
308 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
309 |
+
self.attention_maps = attention_dict
|
310 |
+
|
311 |
+
def register_selfattn_hooks(self, feat_inject_step=False):
|
312 |
+
r"""Function for registering hooks during evaluation.
|
313 |
+
We mainly store activation maps averaged over queries.
|
314 |
+
"""
|
315 |
+
self.selfattn_forward_hooks = []
|
316 |
+
|
317 |
+
def save_activations(activations, name, module, inp, out):
|
318 |
+
r"""
|
319 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
320 |
+
"""
|
321 |
+
# out[0] - final output of attention layer
|
322 |
+
# out[1] - attention probability matrix
|
323 |
+
if 'attn2' in name:
|
324 |
+
assert out[1][1].shape[-1] == 77
|
325 |
+
# cross attention injection
|
326 |
+
# activations[name] = out[1][1].detach()
|
327 |
+
else:
|
328 |
+
assert out[1][1].shape[-1] != 77
|
329 |
+
activations[name] = out[1][1].detach()
|
330 |
+
|
331 |
+
def save_resnet_activations(activations, name, module, inp, out):
|
332 |
+
r"""
|
333 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
334 |
+
"""
|
335 |
+
# out[0] - final output of residual layer
|
336 |
+
# out[1] - residual hidden feature
|
337 |
+
assert out[1].shape[-1] == 16
|
338 |
+
activations[name] = out[1].detach()
|
339 |
+
attention_dict = collections.defaultdict(list)
|
340 |
+
for name, module in self.unet.named_modules():
|
341 |
+
leaf_name = name.split('.')[-1]
|
342 |
+
if 'attn' in leaf_name and feat_inject_step:
|
343 |
+
# Register hook to obtain outputs at every attention layer.
|
344 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
345 |
+
partial(save_activations, attention_dict, name)
|
346 |
+
))
|
347 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
348 |
+
self.selfattn_forward_hooks.append(module.register_forward_hook(
|
349 |
+
partial(save_resnet_activations, attention_dict, name)
|
350 |
+
))
|
351 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
352 |
+
self.self_attention_maps_cur = attention_dict
|
353 |
+
|
354 |
+
def register_replacement_hooks(self, feat_inject_step=False):
|
355 |
+
r"""Function for registering hooks to replace self attention.
|
356 |
+
"""
|
357 |
+
self.forward_replacement_hooks = []
|
358 |
+
|
359 |
+
def replace_activations(name, module, args):
|
360 |
+
r"""
|
361 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
362 |
+
"""
|
363 |
+
if 'attn1' in name:
|
364 |
+
modified_args = (args[0], self.self_attention_maps_cur[name])
|
365 |
+
return modified_args
|
366 |
+
# cross attention injection
|
367 |
+
# elif 'attn2' in name:
|
368 |
+
# modified_map = {
|
369 |
+
# 'reference': self.self_attention_maps_cur[name],
|
370 |
+
# 'inject_pos': self.inject_pos,
|
371 |
+
# }
|
372 |
+
# modified_args = (args[0], modified_map)
|
373 |
+
# return modified_args
|
374 |
+
|
375 |
+
def replace_resnet_activations(name, module, args):
|
376 |
+
r"""
|
377 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
378 |
+
"""
|
379 |
+
modified_args = (args[0], args[1],
|
380 |
+
self.self_attention_maps_cur[name])
|
381 |
+
return modified_args
|
382 |
+
for name, module in self.unet.named_modules():
|
383 |
+
leaf_name = name.split('.')[-1]
|
384 |
+
if 'attn' in leaf_name and feat_inject_step:
|
385 |
+
# Register hook to obtain outputs at every attention layer.
|
386 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
387 |
+
partial(replace_activations, name)
|
388 |
+
))
|
389 |
+
if name == 'up_blocks.1.resnets.1' and feat_inject_step:
|
390 |
+
# Register hook to obtain outputs at every attention layer.
|
391 |
+
self.forward_replacement_hooks.append(module.register_forward_pre_hook(
|
392 |
+
partial(replace_resnet_activations, name)
|
393 |
+
))
|
394 |
+
|
395 |
+
def register_tokenmap_hooks(self):
|
396 |
+
r"""Function for registering hooks during evaluation.
|
397 |
+
We mainly store activation maps averaged over queries.
|
398 |
+
"""
|
399 |
+
self.forward_hooks = []
|
400 |
+
|
401 |
+
def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
|
402 |
+
r"""
|
403 |
+
PyTorch Forward hook to save outputs at each forward pass.
|
404 |
+
"""
|
405 |
+
# out[0] - final output of attention layer
|
406 |
+
# out[1] - attention probability matrices
|
407 |
+
if name in n_maps:
|
408 |
+
n_maps[name] += 1
|
409 |
+
else:
|
410 |
+
n_maps[name] = 1
|
411 |
+
if 'attn2' in name:
|
412 |
+
assert out[1][0].shape[-1] == 77
|
413 |
+
if name in CrossAttentionLayers and n_maps[name] > 10:
|
414 |
+
if name in crossattn_maps:
|
415 |
+
crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
416 |
+
else:
|
417 |
+
crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
418 |
+
else:
|
419 |
+
assert out[1][0].shape[-1] != 77
|
420 |
+
if name in SelfAttentionLayers and n_maps[name] > 10:
|
421 |
+
if name in crossattn_maps:
|
422 |
+
selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
|
423 |
+
else:
|
424 |
+
selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
|
425 |
+
|
426 |
+
selfattn_maps = collections.defaultdict(list)
|
427 |
+
crossattn_maps = collections.defaultdict(list)
|
428 |
+
n_maps = collections.defaultdict(list)
|
429 |
+
|
430 |
+
for name, module in self.unet.named_modules():
|
431 |
+
leaf_name = name.split('.')[-1]
|
432 |
+
if 'attn' in leaf_name:
|
433 |
+
# Register hook to obtain outputs at every attention layer.
|
434 |
+
self.forward_hooks.append(module.register_forward_hook(
|
435 |
+
partial(save_activations, selfattn_maps,
|
436 |
+
crossattn_maps, n_maps, name)
|
437 |
+
))
|
438 |
+
# attention_dict is a dictionary containing attention maps for every attention layer
|
439 |
+
self.selfattn_maps = selfattn_maps
|
440 |
+
self.crossattn_maps = crossattn_maps
|
441 |
+
self.n_maps = n_maps
|
442 |
+
|
443 |
+
def remove_tokenmap_hooks(self):
|
444 |
+
for hook in self.forward_hooks:
|
445 |
+
hook.remove()
|
446 |
+
self.selfattn_maps = None
|
447 |
+
self.crossattn_maps = None
|
448 |
+
self.n_maps = None
|
449 |
+
|
450 |
+
def remove_evaluation_hooks(self):
|
451 |
+
for hook in self.forward_hooks:
|
452 |
+
hook.remove()
|
453 |
+
self.attention_maps = None
|
454 |
+
|
455 |
+
def remove_replacement_hooks(self):
|
456 |
+
for hook in self.forward_replacement_hooks:
|
457 |
+
hook.remove()
|
458 |
+
|
459 |
+
def remove_selfattn_hooks(self):
|
460 |
+
for hook in self.selfattn_forward_hooks:
|
461 |
+
hook.remove()
|
models/unet_2d_blocks.py
ADDED
@@ -0,0 +1,1855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
|
19 |
+
from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, Upsample2D
|
20 |
+
|
21 |
+
|
22 |
+
def get_down_block(
|
23 |
+
down_block_type,
|
24 |
+
num_layers,
|
25 |
+
in_channels,
|
26 |
+
out_channels,
|
27 |
+
temb_channels,
|
28 |
+
add_downsample,
|
29 |
+
resnet_eps,
|
30 |
+
resnet_act_fn,
|
31 |
+
attn_num_head_channels,
|
32 |
+
resnet_groups=None,
|
33 |
+
cross_attention_dim=None,
|
34 |
+
downsample_padding=None,
|
35 |
+
dual_cross_attention=False,
|
36 |
+
use_linear_projection=False,
|
37 |
+
only_cross_attention=False,
|
38 |
+
):
|
39 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith(
|
40 |
+
"UNetRes") else down_block_type
|
41 |
+
if down_block_type == "DownBlock2D":
|
42 |
+
return DownBlock2D(
|
43 |
+
num_layers=num_layers,
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=out_channels,
|
46 |
+
temb_channels=temb_channels,
|
47 |
+
add_downsample=add_downsample,
|
48 |
+
resnet_eps=resnet_eps,
|
49 |
+
resnet_act_fn=resnet_act_fn,
|
50 |
+
resnet_groups=resnet_groups,
|
51 |
+
downsample_padding=downsample_padding,
|
52 |
+
)
|
53 |
+
elif down_block_type == "AttnDownBlock2D":
|
54 |
+
return AttnDownBlock2D(
|
55 |
+
num_layers=num_layers,
|
56 |
+
in_channels=in_channels,
|
57 |
+
out_channels=out_channels,
|
58 |
+
temb_channels=temb_channels,
|
59 |
+
add_downsample=add_downsample,
|
60 |
+
resnet_eps=resnet_eps,
|
61 |
+
resnet_act_fn=resnet_act_fn,
|
62 |
+
resnet_groups=resnet_groups,
|
63 |
+
downsample_padding=downsample_padding,
|
64 |
+
attn_num_head_channels=attn_num_head_channels,
|
65 |
+
)
|
66 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
67 |
+
if cross_attention_dim is None:
|
68 |
+
raise ValueError(
|
69 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
70 |
+
return CrossAttnDownBlock2D(
|
71 |
+
num_layers=num_layers,
|
72 |
+
in_channels=in_channels,
|
73 |
+
out_channels=out_channels,
|
74 |
+
temb_channels=temb_channels,
|
75 |
+
add_downsample=add_downsample,
|
76 |
+
resnet_eps=resnet_eps,
|
77 |
+
resnet_act_fn=resnet_act_fn,
|
78 |
+
resnet_groups=resnet_groups,
|
79 |
+
downsample_padding=downsample_padding,
|
80 |
+
cross_attention_dim=cross_attention_dim,
|
81 |
+
attn_num_head_channels=attn_num_head_channels,
|
82 |
+
dual_cross_attention=dual_cross_attention,
|
83 |
+
use_linear_projection=use_linear_projection,
|
84 |
+
only_cross_attention=only_cross_attention,
|
85 |
+
)
|
86 |
+
elif down_block_type == "SkipDownBlock2D":
|
87 |
+
return SkipDownBlock2D(
|
88 |
+
num_layers=num_layers,
|
89 |
+
in_channels=in_channels,
|
90 |
+
out_channels=out_channels,
|
91 |
+
temb_channels=temb_channels,
|
92 |
+
add_downsample=add_downsample,
|
93 |
+
resnet_eps=resnet_eps,
|
94 |
+
resnet_act_fn=resnet_act_fn,
|
95 |
+
downsample_padding=downsample_padding,
|
96 |
+
)
|
97 |
+
elif down_block_type == "AttnSkipDownBlock2D":
|
98 |
+
return AttnSkipDownBlock2D(
|
99 |
+
num_layers=num_layers,
|
100 |
+
in_channels=in_channels,
|
101 |
+
out_channels=out_channels,
|
102 |
+
temb_channels=temb_channels,
|
103 |
+
add_downsample=add_downsample,
|
104 |
+
resnet_eps=resnet_eps,
|
105 |
+
resnet_act_fn=resnet_act_fn,
|
106 |
+
downsample_padding=downsample_padding,
|
107 |
+
attn_num_head_channels=attn_num_head_channels,
|
108 |
+
)
|
109 |
+
elif down_block_type == "DownEncoderBlock2D":
|
110 |
+
return DownEncoderBlock2D(
|
111 |
+
num_layers=num_layers,
|
112 |
+
in_channels=in_channels,
|
113 |
+
out_channels=out_channels,
|
114 |
+
add_downsample=add_downsample,
|
115 |
+
resnet_eps=resnet_eps,
|
116 |
+
resnet_act_fn=resnet_act_fn,
|
117 |
+
resnet_groups=resnet_groups,
|
118 |
+
downsample_padding=downsample_padding,
|
119 |
+
)
|
120 |
+
elif down_block_type == "AttnDownEncoderBlock2D":
|
121 |
+
return AttnDownEncoderBlock2D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
add_downsample=add_downsample,
|
126 |
+
resnet_eps=resnet_eps,
|
127 |
+
resnet_act_fn=resnet_act_fn,
|
128 |
+
resnet_groups=resnet_groups,
|
129 |
+
downsample_padding=downsample_padding,
|
130 |
+
attn_num_head_channels=attn_num_head_channels,
|
131 |
+
)
|
132 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
133 |
+
|
134 |
+
|
135 |
+
def get_up_block(
|
136 |
+
up_block_type,
|
137 |
+
num_layers,
|
138 |
+
in_channels,
|
139 |
+
out_channels,
|
140 |
+
prev_output_channel,
|
141 |
+
temb_channels,
|
142 |
+
add_upsample,
|
143 |
+
resnet_eps,
|
144 |
+
resnet_act_fn,
|
145 |
+
attn_num_head_channels,
|
146 |
+
resnet_groups=None,
|
147 |
+
cross_attention_dim=None,
|
148 |
+
dual_cross_attention=False,
|
149 |
+
use_linear_projection=False,
|
150 |
+
only_cross_attention=False,
|
151 |
+
):
|
152 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith(
|
153 |
+
"UNetRes") else up_block_type
|
154 |
+
if up_block_type == "UpBlock2D":
|
155 |
+
return UpBlock2D(
|
156 |
+
num_layers=num_layers,
|
157 |
+
in_channels=in_channels,
|
158 |
+
out_channels=out_channels,
|
159 |
+
prev_output_channel=prev_output_channel,
|
160 |
+
temb_channels=temb_channels,
|
161 |
+
add_upsample=add_upsample,
|
162 |
+
resnet_eps=resnet_eps,
|
163 |
+
resnet_act_fn=resnet_act_fn,
|
164 |
+
resnet_groups=resnet_groups,
|
165 |
+
)
|
166 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
167 |
+
if cross_attention_dim is None:
|
168 |
+
raise ValueError(
|
169 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
170 |
+
return CrossAttnUpBlock2D(
|
171 |
+
num_layers=num_layers,
|
172 |
+
in_channels=in_channels,
|
173 |
+
out_channels=out_channels,
|
174 |
+
prev_output_channel=prev_output_channel,
|
175 |
+
temb_channels=temb_channels,
|
176 |
+
add_upsample=add_upsample,
|
177 |
+
resnet_eps=resnet_eps,
|
178 |
+
resnet_act_fn=resnet_act_fn,
|
179 |
+
resnet_groups=resnet_groups,
|
180 |
+
cross_attention_dim=cross_attention_dim,
|
181 |
+
attn_num_head_channels=attn_num_head_channels,
|
182 |
+
dual_cross_attention=dual_cross_attention,
|
183 |
+
use_linear_projection=use_linear_projection,
|
184 |
+
only_cross_attention=only_cross_attention,
|
185 |
+
)
|
186 |
+
elif up_block_type == "AttnUpBlock2D":
|
187 |
+
return AttnUpBlock2D(
|
188 |
+
num_layers=num_layers,
|
189 |
+
in_channels=in_channels,
|
190 |
+
out_channels=out_channels,
|
191 |
+
prev_output_channel=prev_output_channel,
|
192 |
+
temb_channels=temb_channels,
|
193 |
+
add_upsample=add_upsample,
|
194 |
+
resnet_eps=resnet_eps,
|
195 |
+
resnet_act_fn=resnet_act_fn,
|
196 |
+
resnet_groups=resnet_groups,
|
197 |
+
attn_num_head_channels=attn_num_head_channels,
|
198 |
+
)
|
199 |
+
elif up_block_type == "SkipUpBlock2D":
|
200 |
+
return SkipUpBlock2D(
|
201 |
+
num_layers=num_layers,
|
202 |
+
in_channels=in_channels,
|
203 |
+
out_channels=out_channels,
|
204 |
+
prev_output_channel=prev_output_channel,
|
205 |
+
temb_channels=temb_channels,
|
206 |
+
add_upsample=add_upsample,
|
207 |
+
resnet_eps=resnet_eps,
|
208 |
+
resnet_act_fn=resnet_act_fn,
|
209 |
+
)
|
210 |
+
elif up_block_type == "AttnSkipUpBlock2D":
|
211 |
+
return AttnSkipUpBlock2D(
|
212 |
+
num_layers=num_layers,
|
213 |
+
in_channels=in_channels,
|
214 |
+
out_channels=out_channels,
|
215 |
+
prev_output_channel=prev_output_channel,
|
216 |
+
temb_channels=temb_channels,
|
217 |
+
add_upsample=add_upsample,
|
218 |
+
resnet_eps=resnet_eps,
|
219 |
+
resnet_act_fn=resnet_act_fn,
|
220 |
+
attn_num_head_channels=attn_num_head_channels,
|
221 |
+
)
|
222 |
+
elif up_block_type == "UpDecoderBlock2D":
|
223 |
+
return UpDecoderBlock2D(
|
224 |
+
num_layers=num_layers,
|
225 |
+
in_channels=in_channels,
|
226 |
+
out_channels=out_channels,
|
227 |
+
add_upsample=add_upsample,
|
228 |
+
resnet_eps=resnet_eps,
|
229 |
+
resnet_act_fn=resnet_act_fn,
|
230 |
+
resnet_groups=resnet_groups,
|
231 |
+
)
|
232 |
+
elif up_block_type == "AttnUpDecoderBlock2D":
|
233 |
+
return AttnUpDecoderBlock2D(
|
234 |
+
num_layers=num_layers,
|
235 |
+
in_channels=in_channels,
|
236 |
+
out_channels=out_channels,
|
237 |
+
add_upsample=add_upsample,
|
238 |
+
resnet_eps=resnet_eps,
|
239 |
+
resnet_act_fn=resnet_act_fn,
|
240 |
+
resnet_groups=resnet_groups,
|
241 |
+
attn_num_head_channels=attn_num_head_channels,
|
242 |
+
)
|
243 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
244 |
+
|
245 |
+
|
246 |
+
class UNetMidBlock2D(nn.Module):
|
247 |
+
def __init__(
|
248 |
+
self,
|
249 |
+
in_channels: int,
|
250 |
+
temb_channels: int,
|
251 |
+
dropout: float = 0.0,
|
252 |
+
num_layers: int = 1,
|
253 |
+
resnet_eps: float = 1e-6,
|
254 |
+
resnet_time_scale_shift: str = "default",
|
255 |
+
resnet_act_fn: str = "swish",
|
256 |
+
resnet_groups: int = 32,
|
257 |
+
resnet_pre_norm: bool = True,
|
258 |
+
attn_num_head_channels=1,
|
259 |
+
attention_type="default",
|
260 |
+
output_scale_factor=1.0,
|
261 |
+
):
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
self.attention_type = attention_type
|
265 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(
|
266 |
+
in_channels // 4, 32)
|
267 |
+
|
268 |
+
# there is always at least one resnet
|
269 |
+
resnets = [
|
270 |
+
ResnetBlock2D(
|
271 |
+
in_channels=in_channels,
|
272 |
+
out_channels=in_channels,
|
273 |
+
temb_channels=temb_channels,
|
274 |
+
eps=resnet_eps,
|
275 |
+
groups=resnet_groups,
|
276 |
+
dropout=dropout,
|
277 |
+
time_embedding_norm=resnet_time_scale_shift,
|
278 |
+
non_linearity=resnet_act_fn,
|
279 |
+
output_scale_factor=output_scale_factor,
|
280 |
+
pre_norm=resnet_pre_norm,
|
281 |
+
)
|
282 |
+
]
|
283 |
+
attentions = []
|
284 |
+
|
285 |
+
for _ in range(num_layers):
|
286 |
+
attentions.append(
|
287 |
+
AttentionBlock(
|
288 |
+
in_channels,
|
289 |
+
num_head_channels=attn_num_head_channels,
|
290 |
+
rescale_output_factor=output_scale_factor,
|
291 |
+
eps=resnet_eps,
|
292 |
+
norm_num_groups=resnet_groups,
|
293 |
+
)
|
294 |
+
)
|
295 |
+
resnets.append(
|
296 |
+
ResnetBlock2D(
|
297 |
+
in_channels=in_channels,
|
298 |
+
out_channels=in_channels,
|
299 |
+
temb_channels=temb_channels,
|
300 |
+
eps=resnet_eps,
|
301 |
+
groups=resnet_groups,
|
302 |
+
dropout=dropout,
|
303 |
+
time_embedding_norm=resnet_time_scale_shift,
|
304 |
+
non_linearity=resnet_act_fn,
|
305 |
+
output_scale_factor=output_scale_factor,
|
306 |
+
pre_norm=resnet_pre_norm,
|
307 |
+
)
|
308 |
+
)
|
309 |
+
|
310 |
+
self.attentions = nn.ModuleList(attentions)
|
311 |
+
self.resnets = nn.ModuleList(resnets)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, temb=None, encoder_states=None):
|
314 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
315 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
316 |
+
if self.attention_type == "default":
|
317 |
+
hidden_states = attn(hidden_states)
|
318 |
+
else:
|
319 |
+
hidden_states = attn(hidden_states, encoder_states)
|
320 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
321 |
+
|
322 |
+
return hidden_states
|
323 |
+
|
324 |
+
|
325 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
in_channels: int,
|
329 |
+
temb_channels: int,
|
330 |
+
dropout: float = 0.0,
|
331 |
+
num_layers: int = 1,
|
332 |
+
resnet_eps: float = 1e-6,
|
333 |
+
resnet_time_scale_shift: str = "default",
|
334 |
+
resnet_act_fn: str = "swish",
|
335 |
+
resnet_groups: int = 32,
|
336 |
+
resnet_pre_norm: bool = True,
|
337 |
+
attn_num_head_channels=1,
|
338 |
+
attention_type="default",
|
339 |
+
output_scale_factor=1.0,
|
340 |
+
cross_attention_dim=1280,
|
341 |
+
dual_cross_attention=False,
|
342 |
+
use_linear_projection=False,
|
343 |
+
):
|
344 |
+
super().__init__()
|
345 |
+
|
346 |
+
self.attention_type = attention_type
|
347 |
+
self.attn_num_head_channels = attn_num_head_channels
|
348 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(
|
349 |
+
in_channels // 4, 32)
|
350 |
+
|
351 |
+
# there is always at least one resnet
|
352 |
+
resnets = [
|
353 |
+
ResnetBlock2D(
|
354 |
+
in_channels=in_channels,
|
355 |
+
out_channels=in_channels,
|
356 |
+
temb_channels=temb_channels,
|
357 |
+
eps=resnet_eps,
|
358 |
+
groups=resnet_groups,
|
359 |
+
dropout=dropout,
|
360 |
+
time_embedding_norm=resnet_time_scale_shift,
|
361 |
+
non_linearity=resnet_act_fn,
|
362 |
+
output_scale_factor=output_scale_factor,
|
363 |
+
pre_norm=resnet_pre_norm,
|
364 |
+
)
|
365 |
+
]
|
366 |
+
attentions = []
|
367 |
+
|
368 |
+
for _ in range(num_layers):
|
369 |
+
if not dual_cross_attention:
|
370 |
+
attentions.append(
|
371 |
+
Transformer2DModel(
|
372 |
+
attn_num_head_channels,
|
373 |
+
in_channels // attn_num_head_channels,
|
374 |
+
in_channels=in_channels,
|
375 |
+
num_layers=1,
|
376 |
+
cross_attention_dim=cross_attention_dim,
|
377 |
+
norm_num_groups=resnet_groups,
|
378 |
+
use_linear_projection=use_linear_projection,
|
379 |
+
)
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
attentions.append(
|
383 |
+
DualTransformer2DModel(
|
384 |
+
attn_num_head_channels,
|
385 |
+
in_channels // attn_num_head_channels,
|
386 |
+
in_channels=in_channels,
|
387 |
+
num_layers=1,
|
388 |
+
cross_attention_dim=cross_attention_dim,
|
389 |
+
norm_num_groups=resnet_groups,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
resnets.append(
|
393 |
+
ResnetBlock2D(
|
394 |
+
in_channels=in_channels,
|
395 |
+
out_channels=in_channels,
|
396 |
+
temb_channels=temb_channels,
|
397 |
+
eps=resnet_eps,
|
398 |
+
groups=resnet_groups,
|
399 |
+
dropout=dropout,
|
400 |
+
time_embedding_norm=resnet_time_scale_shift,
|
401 |
+
non_linearity=resnet_act_fn,
|
402 |
+
output_scale_factor=output_scale_factor,
|
403 |
+
pre_norm=resnet_pre_norm,
|
404 |
+
)
|
405 |
+
)
|
406 |
+
|
407 |
+
self.attentions = nn.ModuleList(attentions)
|
408 |
+
self.resnets = nn.ModuleList(resnets)
|
409 |
+
|
410 |
+
def set_attention_slice(self, slice_size):
|
411 |
+
head_dims = self.attn_num_head_channels
|
412 |
+
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
413 |
+
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
414 |
+
raise ValueError(
|
415 |
+
f"Make sure slice_size {slice_size} is a common divisor of "
|
416 |
+
f"the number of heads used in cross_attention: {head_dims}"
|
417 |
+
)
|
418 |
+
if slice_size is not None and slice_size > min(head_dims):
|
419 |
+
raise ValueError(
|
420 |
+
f"slice_size {slice_size} has to be smaller or equal to "
|
421 |
+
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
422 |
+
)
|
423 |
+
|
424 |
+
for attn in self.attentions:
|
425 |
+
attn._set_attention_slice(slice_size)
|
426 |
+
|
427 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
428 |
+
for attn in self.attentions:
|
429 |
+
attn._set_use_memory_efficient_attention_xformers(
|
430 |
+
use_memory_efficient_attention_xformers)
|
431 |
+
|
432 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
|
433 |
+
text_format_dict={}):
|
434 |
+
hidden_states, _ = self.resnets[0](hidden_states, temb)
|
435 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
436 |
+
hidden_states = attn(hidden_states, encoder_hidden_states,
|
437 |
+
text_format_dict).sample
|
438 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
439 |
+
|
440 |
+
return hidden_states
|
441 |
+
|
442 |
+
|
443 |
+
class AttnDownBlock2D(nn.Module):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
in_channels: int,
|
447 |
+
out_channels: int,
|
448 |
+
temb_channels: int,
|
449 |
+
dropout: float = 0.0,
|
450 |
+
num_layers: int = 1,
|
451 |
+
resnet_eps: float = 1e-6,
|
452 |
+
resnet_time_scale_shift: str = "default",
|
453 |
+
resnet_act_fn: str = "swish",
|
454 |
+
resnet_groups: int = 32,
|
455 |
+
resnet_pre_norm: bool = True,
|
456 |
+
attn_num_head_channels=1,
|
457 |
+
attention_type="default",
|
458 |
+
output_scale_factor=1.0,
|
459 |
+
downsample_padding=1,
|
460 |
+
add_downsample=True,
|
461 |
+
):
|
462 |
+
super().__init__()
|
463 |
+
resnets = []
|
464 |
+
attentions = []
|
465 |
+
|
466 |
+
self.attention_type = attention_type
|
467 |
+
|
468 |
+
for i in range(num_layers):
|
469 |
+
in_channels = in_channels if i == 0 else out_channels
|
470 |
+
resnets.append(
|
471 |
+
ResnetBlock2D(
|
472 |
+
in_channels=in_channels,
|
473 |
+
out_channels=out_channels,
|
474 |
+
temb_channels=temb_channels,
|
475 |
+
eps=resnet_eps,
|
476 |
+
groups=resnet_groups,
|
477 |
+
dropout=dropout,
|
478 |
+
time_embedding_norm=resnet_time_scale_shift,
|
479 |
+
non_linearity=resnet_act_fn,
|
480 |
+
output_scale_factor=output_scale_factor,
|
481 |
+
pre_norm=resnet_pre_norm,
|
482 |
+
)
|
483 |
+
)
|
484 |
+
attentions.append(
|
485 |
+
AttentionBlock(
|
486 |
+
out_channels,
|
487 |
+
num_head_channels=attn_num_head_channels,
|
488 |
+
rescale_output_factor=output_scale_factor,
|
489 |
+
eps=resnet_eps,
|
490 |
+
norm_num_groups=resnet_groups,
|
491 |
+
)
|
492 |
+
)
|
493 |
+
|
494 |
+
self.attentions = nn.ModuleList(attentions)
|
495 |
+
self.resnets = nn.ModuleList(resnets)
|
496 |
+
|
497 |
+
if add_downsample:
|
498 |
+
self.downsamplers = nn.ModuleList(
|
499 |
+
[
|
500 |
+
Downsample2D(
|
501 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
502 |
+
)
|
503 |
+
]
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
self.downsamplers = None
|
507 |
+
|
508 |
+
def forward(self, hidden_states, temb=None):
|
509 |
+
output_states = ()
|
510 |
+
|
511 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
512 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
513 |
+
hidden_states = attn(hidden_states)
|
514 |
+
output_states += (hidden_states,)
|
515 |
+
|
516 |
+
if self.downsamplers is not None:
|
517 |
+
for downsampler in self.downsamplers:
|
518 |
+
hidden_states = downsampler(hidden_states)
|
519 |
+
|
520 |
+
output_states += (hidden_states,)
|
521 |
+
|
522 |
+
return hidden_states, output_states
|
523 |
+
|
524 |
+
|
525 |
+
class CrossAttnDownBlock2D(nn.Module):
|
526 |
+
def __init__(
|
527 |
+
self,
|
528 |
+
in_channels: int,
|
529 |
+
out_channels: int,
|
530 |
+
temb_channels: int,
|
531 |
+
dropout: float = 0.0,
|
532 |
+
num_layers: int = 1,
|
533 |
+
resnet_eps: float = 1e-6,
|
534 |
+
resnet_time_scale_shift: str = "default",
|
535 |
+
resnet_act_fn: str = "swish",
|
536 |
+
resnet_groups: int = 32,
|
537 |
+
resnet_pre_norm: bool = True,
|
538 |
+
attn_num_head_channels=1,
|
539 |
+
cross_attention_dim=1280,
|
540 |
+
attention_type="default",
|
541 |
+
output_scale_factor=1.0,
|
542 |
+
downsample_padding=1,
|
543 |
+
add_downsample=True,
|
544 |
+
dual_cross_attention=False,
|
545 |
+
use_linear_projection=False,
|
546 |
+
only_cross_attention=False,
|
547 |
+
):
|
548 |
+
super().__init__()
|
549 |
+
resnets = []
|
550 |
+
attentions = []
|
551 |
+
|
552 |
+
self.attention_type = attention_type
|
553 |
+
self.attn_num_head_channels = attn_num_head_channels
|
554 |
+
|
555 |
+
for i in range(num_layers):
|
556 |
+
in_channels = in_channels if i == 0 else out_channels
|
557 |
+
resnets.append(
|
558 |
+
ResnetBlock2D(
|
559 |
+
in_channels=in_channels,
|
560 |
+
out_channels=out_channels,
|
561 |
+
temb_channels=temb_channels,
|
562 |
+
eps=resnet_eps,
|
563 |
+
groups=resnet_groups,
|
564 |
+
dropout=dropout,
|
565 |
+
time_embedding_norm=resnet_time_scale_shift,
|
566 |
+
non_linearity=resnet_act_fn,
|
567 |
+
output_scale_factor=output_scale_factor,
|
568 |
+
pre_norm=resnet_pre_norm,
|
569 |
+
)
|
570 |
+
)
|
571 |
+
if not dual_cross_attention:
|
572 |
+
attentions.append(
|
573 |
+
Transformer2DModel(
|
574 |
+
attn_num_head_channels,
|
575 |
+
out_channels // attn_num_head_channels,
|
576 |
+
in_channels=out_channels,
|
577 |
+
num_layers=1,
|
578 |
+
cross_attention_dim=cross_attention_dim,
|
579 |
+
norm_num_groups=resnet_groups,
|
580 |
+
use_linear_projection=use_linear_projection,
|
581 |
+
only_cross_attention=only_cross_attention,
|
582 |
+
)
|
583 |
+
)
|
584 |
+
else:
|
585 |
+
attentions.append(
|
586 |
+
DualTransformer2DModel(
|
587 |
+
attn_num_head_channels,
|
588 |
+
out_channels // attn_num_head_channels,
|
589 |
+
in_channels=out_channels,
|
590 |
+
num_layers=1,
|
591 |
+
cross_attention_dim=cross_attention_dim,
|
592 |
+
norm_num_groups=resnet_groups,
|
593 |
+
)
|
594 |
+
)
|
595 |
+
self.attentions = nn.ModuleList(attentions)
|
596 |
+
self.resnets = nn.ModuleList(resnets)
|
597 |
+
|
598 |
+
if add_downsample:
|
599 |
+
self.downsamplers = nn.ModuleList(
|
600 |
+
[
|
601 |
+
Downsample2D(
|
602 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
603 |
+
)
|
604 |
+
]
|
605 |
+
)
|
606 |
+
else:
|
607 |
+
self.downsamplers = None
|
608 |
+
|
609 |
+
self.gradient_checkpointing = False
|
610 |
+
|
611 |
+
def set_attention_slice(self, slice_size):
|
612 |
+
head_dims = self.attn_num_head_channels
|
613 |
+
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
614 |
+
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
615 |
+
raise ValueError(
|
616 |
+
f"Make sure slice_size {slice_size} is a common divisor of "
|
617 |
+
f"the number of heads used in cross_attention: {head_dims}"
|
618 |
+
)
|
619 |
+
if slice_size is not None and slice_size > min(head_dims):
|
620 |
+
raise ValueError(
|
621 |
+
f"slice_size {slice_size} has to be smaller or equal to "
|
622 |
+
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
623 |
+
)
|
624 |
+
|
625 |
+
for attn in self.attentions:
|
626 |
+
attn._set_attention_slice(slice_size)
|
627 |
+
|
628 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
629 |
+
for attn in self.attentions:
|
630 |
+
attn._set_use_memory_efficient_attention_xformers(
|
631 |
+
use_memory_efficient_attention_xformers)
|
632 |
+
|
633 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
|
634 |
+
text_format_dict={}):
|
635 |
+
output_states = ()
|
636 |
+
|
637 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
638 |
+
if self.training and self.gradient_checkpointing:
|
639 |
+
|
640 |
+
def create_custom_forward(module, return_dict=None):
|
641 |
+
def custom_forward(*inputs):
|
642 |
+
if return_dict is not None:
|
643 |
+
return module(*inputs, return_dict=return_dict)
|
644 |
+
else:
|
645 |
+
return module(*inputs)
|
646 |
+
|
647 |
+
return custom_forward
|
648 |
+
|
649 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
650 |
+
create_custom_forward(resnet), hidden_states, temb)
|
651 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
652 |
+
create_custom_forward(
|
653 |
+
attn, return_dict=False), hidden_states, encoder_hidden_states,
|
654 |
+
text_format_dict
|
655 |
+
)[0]
|
656 |
+
else:
|
657 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
658 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
|
659 |
+
text_format_dict=text_format_dict).sample
|
660 |
+
|
661 |
+
output_states += (hidden_states,)
|
662 |
+
|
663 |
+
if self.downsamplers is not None:
|
664 |
+
for downsampler in self.downsamplers:
|
665 |
+
hidden_states = downsampler(hidden_states)
|
666 |
+
|
667 |
+
output_states += (hidden_states,)
|
668 |
+
|
669 |
+
return hidden_states, output_states
|
670 |
+
|
671 |
+
|
672 |
+
class DownBlock2D(nn.Module):
|
673 |
+
def __init__(
|
674 |
+
self,
|
675 |
+
in_channels: int,
|
676 |
+
out_channels: int,
|
677 |
+
temb_channels: int,
|
678 |
+
dropout: float = 0.0,
|
679 |
+
num_layers: int = 1,
|
680 |
+
resnet_eps: float = 1e-6,
|
681 |
+
resnet_time_scale_shift: str = "default",
|
682 |
+
resnet_act_fn: str = "swish",
|
683 |
+
resnet_groups: int = 32,
|
684 |
+
resnet_pre_norm: bool = True,
|
685 |
+
output_scale_factor=1.0,
|
686 |
+
add_downsample=True,
|
687 |
+
downsample_padding=1,
|
688 |
+
):
|
689 |
+
super().__init__()
|
690 |
+
resnets = []
|
691 |
+
|
692 |
+
for i in range(num_layers):
|
693 |
+
in_channels = in_channels if i == 0 else out_channels
|
694 |
+
resnets.append(
|
695 |
+
ResnetBlock2D(
|
696 |
+
in_channels=in_channels,
|
697 |
+
out_channels=out_channels,
|
698 |
+
temb_channels=temb_channels,
|
699 |
+
eps=resnet_eps,
|
700 |
+
groups=resnet_groups,
|
701 |
+
dropout=dropout,
|
702 |
+
time_embedding_norm=resnet_time_scale_shift,
|
703 |
+
non_linearity=resnet_act_fn,
|
704 |
+
output_scale_factor=output_scale_factor,
|
705 |
+
pre_norm=resnet_pre_norm,
|
706 |
+
)
|
707 |
+
)
|
708 |
+
|
709 |
+
self.resnets = nn.ModuleList(resnets)
|
710 |
+
|
711 |
+
if add_downsample:
|
712 |
+
self.downsamplers = nn.ModuleList(
|
713 |
+
[
|
714 |
+
Downsample2D(
|
715 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
716 |
+
)
|
717 |
+
]
|
718 |
+
)
|
719 |
+
else:
|
720 |
+
self.downsamplers = None
|
721 |
+
|
722 |
+
self.gradient_checkpointing = False
|
723 |
+
|
724 |
+
def forward(self, hidden_states, temb=None):
|
725 |
+
output_states = ()
|
726 |
+
|
727 |
+
for resnet in self.resnets:
|
728 |
+
if self.training and self.gradient_checkpointing:
|
729 |
+
|
730 |
+
def create_custom_forward(module):
|
731 |
+
def custom_forward(*inputs):
|
732 |
+
return module(*inputs)
|
733 |
+
|
734 |
+
return custom_forward
|
735 |
+
|
736 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
737 |
+
create_custom_forward(resnet), hidden_states, temb)
|
738 |
+
else:
|
739 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
740 |
+
|
741 |
+
output_states += (hidden_states,)
|
742 |
+
|
743 |
+
if self.downsamplers is not None:
|
744 |
+
for downsampler in self.downsamplers:
|
745 |
+
hidden_states = downsampler(hidden_states)
|
746 |
+
|
747 |
+
output_states += (hidden_states,)
|
748 |
+
|
749 |
+
return hidden_states, output_states
|
750 |
+
|
751 |
+
|
752 |
+
class DownEncoderBlock2D(nn.Module):
|
753 |
+
def __init__(
|
754 |
+
self,
|
755 |
+
in_channels: int,
|
756 |
+
out_channels: int,
|
757 |
+
dropout: float = 0.0,
|
758 |
+
num_layers: int = 1,
|
759 |
+
resnet_eps: float = 1e-6,
|
760 |
+
resnet_time_scale_shift: str = "default",
|
761 |
+
resnet_act_fn: str = "swish",
|
762 |
+
resnet_groups: int = 32,
|
763 |
+
resnet_pre_norm: bool = True,
|
764 |
+
output_scale_factor=1.0,
|
765 |
+
add_downsample=True,
|
766 |
+
downsample_padding=1,
|
767 |
+
):
|
768 |
+
super().__init__()
|
769 |
+
resnets = []
|
770 |
+
|
771 |
+
for i in range(num_layers):
|
772 |
+
in_channels = in_channels if i == 0 else out_channels
|
773 |
+
resnets.append(
|
774 |
+
ResnetBlock2D(
|
775 |
+
in_channels=in_channels,
|
776 |
+
out_channels=out_channels,
|
777 |
+
temb_channels=None,
|
778 |
+
eps=resnet_eps,
|
779 |
+
groups=resnet_groups,
|
780 |
+
dropout=dropout,
|
781 |
+
time_embedding_norm=resnet_time_scale_shift,
|
782 |
+
non_linearity=resnet_act_fn,
|
783 |
+
output_scale_factor=output_scale_factor,
|
784 |
+
pre_norm=resnet_pre_norm,
|
785 |
+
)
|
786 |
+
)
|
787 |
+
|
788 |
+
self.resnets = nn.ModuleList(resnets)
|
789 |
+
|
790 |
+
if add_downsample:
|
791 |
+
self.downsamplers = nn.ModuleList(
|
792 |
+
[
|
793 |
+
Downsample2D(
|
794 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
795 |
+
)
|
796 |
+
]
|
797 |
+
)
|
798 |
+
else:
|
799 |
+
self.downsamplers = None
|
800 |
+
|
801 |
+
def forward(self, hidden_states):
|
802 |
+
for resnet in self.resnets:
|
803 |
+
hidden_states, _ = resnet(hidden_states, temb=None)
|
804 |
+
|
805 |
+
if self.downsamplers is not None:
|
806 |
+
for downsampler in self.downsamplers:
|
807 |
+
hidden_states = downsampler(hidden_states)
|
808 |
+
|
809 |
+
return hidden_states
|
810 |
+
|
811 |
+
|
812 |
+
class AttnDownEncoderBlock2D(nn.Module):
|
813 |
+
def __init__(
|
814 |
+
self,
|
815 |
+
in_channels: int,
|
816 |
+
out_channels: int,
|
817 |
+
dropout: float = 0.0,
|
818 |
+
num_layers: int = 1,
|
819 |
+
resnet_eps: float = 1e-6,
|
820 |
+
resnet_time_scale_shift: str = "default",
|
821 |
+
resnet_act_fn: str = "swish",
|
822 |
+
resnet_groups: int = 32,
|
823 |
+
resnet_pre_norm: bool = True,
|
824 |
+
attn_num_head_channels=1,
|
825 |
+
output_scale_factor=1.0,
|
826 |
+
add_downsample=True,
|
827 |
+
downsample_padding=1,
|
828 |
+
):
|
829 |
+
super().__init__()
|
830 |
+
resnets = []
|
831 |
+
attentions = []
|
832 |
+
|
833 |
+
for i in range(num_layers):
|
834 |
+
in_channels = in_channels if i == 0 else out_channels
|
835 |
+
resnets.append(
|
836 |
+
ResnetBlock2D(
|
837 |
+
in_channels=in_channels,
|
838 |
+
out_channels=out_channels,
|
839 |
+
temb_channels=None,
|
840 |
+
eps=resnet_eps,
|
841 |
+
groups=resnet_groups,
|
842 |
+
dropout=dropout,
|
843 |
+
time_embedding_norm=resnet_time_scale_shift,
|
844 |
+
non_linearity=resnet_act_fn,
|
845 |
+
output_scale_factor=output_scale_factor,
|
846 |
+
pre_norm=resnet_pre_norm,
|
847 |
+
)
|
848 |
+
)
|
849 |
+
attentions.append(
|
850 |
+
AttentionBlock(
|
851 |
+
out_channels,
|
852 |
+
num_head_channels=attn_num_head_channels,
|
853 |
+
rescale_output_factor=output_scale_factor,
|
854 |
+
eps=resnet_eps,
|
855 |
+
norm_num_groups=resnet_groups,
|
856 |
+
)
|
857 |
+
)
|
858 |
+
|
859 |
+
self.attentions = nn.ModuleList(attentions)
|
860 |
+
self.resnets = nn.ModuleList(resnets)
|
861 |
+
|
862 |
+
if add_downsample:
|
863 |
+
self.downsamplers = nn.ModuleList(
|
864 |
+
[
|
865 |
+
Downsample2D(
|
866 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
867 |
+
)
|
868 |
+
]
|
869 |
+
)
|
870 |
+
else:
|
871 |
+
self.downsamplers = None
|
872 |
+
|
873 |
+
def forward(self, hidden_states):
|
874 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
875 |
+
hidden_states, _ = resnet(hidden_states, temb=None)
|
876 |
+
hidden_states = attn(hidden_states)
|
877 |
+
|
878 |
+
if self.downsamplers is not None:
|
879 |
+
for downsampler in self.downsamplers:
|
880 |
+
hidden_states = downsampler(hidden_states)
|
881 |
+
|
882 |
+
return hidden_states
|
883 |
+
|
884 |
+
|
885 |
+
class AttnSkipDownBlock2D(nn.Module):
|
886 |
+
def __init__(
|
887 |
+
self,
|
888 |
+
in_channels: int,
|
889 |
+
out_channels: int,
|
890 |
+
temb_channels: int,
|
891 |
+
dropout: float = 0.0,
|
892 |
+
num_layers: int = 1,
|
893 |
+
resnet_eps: float = 1e-6,
|
894 |
+
resnet_time_scale_shift: str = "default",
|
895 |
+
resnet_act_fn: str = "swish",
|
896 |
+
resnet_pre_norm: bool = True,
|
897 |
+
attn_num_head_channels=1,
|
898 |
+
attention_type="default",
|
899 |
+
output_scale_factor=np.sqrt(2.0),
|
900 |
+
downsample_padding=1,
|
901 |
+
add_downsample=True,
|
902 |
+
):
|
903 |
+
super().__init__()
|
904 |
+
self.attentions = nn.ModuleList([])
|
905 |
+
self.resnets = nn.ModuleList([])
|
906 |
+
|
907 |
+
self.attention_type = attention_type
|
908 |
+
|
909 |
+
for i in range(num_layers):
|
910 |
+
in_channels = in_channels if i == 0 else out_channels
|
911 |
+
self.resnets.append(
|
912 |
+
ResnetBlock2D(
|
913 |
+
in_channels=in_channels,
|
914 |
+
out_channels=out_channels,
|
915 |
+
temb_channels=temb_channels,
|
916 |
+
eps=resnet_eps,
|
917 |
+
groups=min(in_channels // 4, 32),
|
918 |
+
groups_out=min(out_channels // 4, 32),
|
919 |
+
dropout=dropout,
|
920 |
+
time_embedding_norm=resnet_time_scale_shift,
|
921 |
+
non_linearity=resnet_act_fn,
|
922 |
+
output_scale_factor=output_scale_factor,
|
923 |
+
pre_norm=resnet_pre_norm,
|
924 |
+
)
|
925 |
+
)
|
926 |
+
self.attentions.append(
|
927 |
+
AttentionBlock(
|
928 |
+
out_channels,
|
929 |
+
num_head_channels=attn_num_head_channels,
|
930 |
+
rescale_output_factor=output_scale_factor,
|
931 |
+
eps=resnet_eps,
|
932 |
+
)
|
933 |
+
)
|
934 |
+
|
935 |
+
if add_downsample:
|
936 |
+
self.resnet_down = ResnetBlock2D(
|
937 |
+
in_channels=out_channels,
|
938 |
+
out_channels=out_channels,
|
939 |
+
temb_channels=temb_channels,
|
940 |
+
eps=resnet_eps,
|
941 |
+
groups=min(out_channels // 4, 32),
|
942 |
+
dropout=dropout,
|
943 |
+
time_embedding_norm=resnet_time_scale_shift,
|
944 |
+
non_linearity=resnet_act_fn,
|
945 |
+
output_scale_factor=output_scale_factor,
|
946 |
+
pre_norm=resnet_pre_norm,
|
947 |
+
use_in_shortcut=True,
|
948 |
+
down=True,
|
949 |
+
kernel="fir",
|
950 |
+
)
|
951 |
+
self.downsamplers = nn.ModuleList(
|
952 |
+
[FirDownsample2D(out_channels, out_channels=out_channels)])
|
953 |
+
self.skip_conv = nn.Conv2d(
|
954 |
+
3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
955 |
+
else:
|
956 |
+
self.resnet_down = None
|
957 |
+
self.downsamplers = None
|
958 |
+
self.skip_conv = None
|
959 |
+
|
960 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
961 |
+
output_states = ()
|
962 |
+
|
963 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
964 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
965 |
+
hidden_states = attn(hidden_states)
|
966 |
+
output_states += (hidden_states,)
|
967 |
+
|
968 |
+
if self.downsamplers is not None:
|
969 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
970 |
+
for downsampler in self.downsamplers:
|
971 |
+
skip_sample = downsampler(skip_sample)
|
972 |
+
|
973 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
974 |
+
|
975 |
+
output_states += (hidden_states,)
|
976 |
+
|
977 |
+
return hidden_states, output_states, skip_sample
|
978 |
+
|
979 |
+
|
980 |
+
class SkipDownBlock2D(nn.Module):
|
981 |
+
def __init__(
|
982 |
+
self,
|
983 |
+
in_channels: int,
|
984 |
+
out_channels: int,
|
985 |
+
temb_channels: int,
|
986 |
+
dropout: float = 0.0,
|
987 |
+
num_layers: int = 1,
|
988 |
+
resnet_eps: float = 1e-6,
|
989 |
+
resnet_time_scale_shift: str = "default",
|
990 |
+
resnet_act_fn: str = "swish",
|
991 |
+
resnet_pre_norm: bool = True,
|
992 |
+
output_scale_factor=np.sqrt(2.0),
|
993 |
+
add_downsample=True,
|
994 |
+
downsample_padding=1,
|
995 |
+
):
|
996 |
+
super().__init__()
|
997 |
+
self.resnets = nn.ModuleList([])
|
998 |
+
|
999 |
+
for i in range(num_layers):
|
1000 |
+
in_channels = in_channels if i == 0 else out_channels
|
1001 |
+
self.resnets.append(
|
1002 |
+
ResnetBlock2D(
|
1003 |
+
in_channels=in_channels,
|
1004 |
+
out_channels=out_channels,
|
1005 |
+
temb_channels=temb_channels,
|
1006 |
+
eps=resnet_eps,
|
1007 |
+
groups=min(in_channels // 4, 32),
|
1008 |
+
groups_out=min(out_channels // 4, 32),
|
1009 |
+
dropout=dropout,
|
1010 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1011 |
+
non_linearity=resnet_act_fn,
|
1012 |
+
output_scale_factor=output_scale_factor,
|
1013 |
+
pre_norm=resnet_pre_norm,
|
1014 |
+
)
|
1015 |
+
)
|
1016 |
+
|
1017 |
+
if add_downsample:
|
1018 |
+
self.resnet_down = ResnetBlock2D(
|
1019 |
+
in_channels=out_channels,
|
1020 |
+
out_channels=out_channels,
|
1021 |
+
temb_channels=temb_channels,
|
1022 |
+
eps=resnet_eps,
|
1023 |
+
groups=min(out_channels // 4, 32),
|
1024 |
+
dropout=dropout,
|
1025 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1026 |
+
non_linearity=resnet_act_fn,
|
1027 |
+
output_scale_factor=output_scale_factor,
|
1028 |
+
pre_norm=resnet_pre_norm,
|
1029 |
+
use_in_shortcut=True,
|
1030 |
+
down=True,
|
1031 |
+
kernel="fir",
|
1032 |
+
)
|
1033 |
+
self.downsamplers = nn.ModuleList(
|
1034 |
+
[FirDownsample2D(out_channels, out_channels=out_channels)])
|
1035 |
+
self.skip_conv = nn.Conv2d(
|
1036 |
+
3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
1037 |
+
else:
|
1038 |
+
self.resnet_down = None
|
1039 |
+
self.downsamplers = None
|
1040 |
+
self.skip_conv = None
|
1041 |
+
|
1042 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
1043 |
+
output_states = ()
|
1044 |
+
|
1045 |
+
for resnet in self.resnets:
|
1046 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1047 |
+
output_states += (hidden_states,)
|
1048 |
+
|
1049 |
+
if self.downsamplers is not None:
|
1050 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
1051 |
+
for downsampler in self.downsamplers:
|
1052 |
+
skip_sample = downsampler(skip_sample)
|
1053 |
+
|
1054 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
1055 |
+
|
1056 |
+
output_states += (hidden_states,)
|
1057 |
+
|
1058 |
+
return hidden_states, output_states, skip_sample
|
1059 |
+
|
1060 |
+
|
1061 |
+
class AttnUpBlock2D(nn.Module):
|
1062 |
+
def __init__(
|
1063 |
+
self,
|
1064 |
+
in_channels: int,
|
1065 |
+
prev_output_channel: int,
|
1066 |
+
out_channels: int,
|
1067 |
+
temb_channels: int,
|
1068 |
+
dropout: float = 0.0,
|
1069 |
+
num_layers: int = 1,
|
1070 |
+
resnet_eps: float = 1e-6,
|
1071 |
+
resnet_time_scale_shift: str = "default",
|
1072 |
+
resnet_act_fn: str = "swish",
|
1073 |
+
resnet_groups: int = 32,
|
1074 |
+
resnet_pre_norm: bool = True,
|
1075 |
+
attention_type="default",
|
1076 |
+
attn_num_head_channels=1,
|
1077 |
+
output_scale_factor=1.0,
|
1078 |
+
add_upsample=True,
|
1079 |
+
):
|
1080 |
+
super().__init__()
|
1081 |
+
resnets = []
|
1082 |
+
attentions = []
|
1083 |
+
|
1084 |
+
self.attention_type = attention_type
|
1085 |
+
|
1086 |
+
for i in range(num_layers):
|
1087 |
+
res_skip_channels = in_channels if (
|
1088 |
+
i == num_layers - 1) else out_channels
|
1089 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1090 |
+
|
1091 |
+
resnets.append(
|
1092 |
+
ResnetBlock2D(
|
1093 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1094 |
+
out_channels=out_channels,
|
1095 |
+
temb_channels=temb_channels,
|
1096 |
+
eps=resnet_eps,
|
1097 |
+
groups=resnet_groups,
|
1098 |
+
dropout=dropout,
|
1099 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1100 |
+
non_linearity=resnet_act_fn,
|
1101 |
+
output_scale_factor=output_scale_factor,
|
1102 |
+
pre_norm=resnet_pre_norm,
|
1103 |
+
)
|
1104 |
+
)
|
1105 |
+
attentions.append(
|
1106 |
+
AttentionBlock(
|
1107 |
+
out_channels,
|
1108 |
+
num_head_channels=attn_num_head_channels,
|
1109 |
+
rescale_output_factor=output_scale_factor,
|
1110 |
+
eps=resnet_eps,
|
1111 |
+
norm_num_groups=resnet_groups,
|
1112 |
+
)
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
self.attentions = nn.ModuleList(attentions)
|
1116 |
+
self.resnets = nn.ModuleList(resnets)
|
1117 |
+
|
1118 |
+
if add_upsample:
|
1119 |
+
self.upsamplers = nn.ModuleList(
|
1120 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1121 |
+
else:
|
1122 |
+
self.upsamplers = None
|
1123 |
+
|
1124 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
1125 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1126 |
+
# pop res hidden states
|
1127 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1128 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1129 |
+
hidden_states = torch.cat(
|
1130 |
+
[hidden_states, res_hidden_states], dim=1)
|
1131 |
+
|
1132 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1133 |
+
hidden_states = attn(hidden_states)
|
1134 |
+
|
1135 |
+
if self.upsamplers is not None:
|
1136 |
+
for upsampler in self.upsamplers:
|
1137 |
+
hidden_states = upsampler(hidden_states)
|
1138 |
+
|
1139 |
+
return hidden_states
|
1140 |
+
|
1141 |
+
|
1142 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1143 |
+
def __init__(
|
1144 |
+
self,
|
1145 |
+
in_channels: int,
|
1146 |
+
out_channels: int,
|
1147 |
+
prev_output_channel: int,
|
1148 |
+
temb_channels: int,
|
1149 |
+
dropout: float = 0.0,
|
1150 |
+
num_layers: int = 1,
|
1151 |
+
resnet_eps: float = 1e-6,
|
1152 |
+
resnet_time_scale_shift: str = "default",
|
1153 |
+
resnet_act_fn: str = "swish",
|
1154 |
+
resnet_groups: int = 32,
|
1155 |
+
resnet_pre_norm: bool = True,
|
1156 |
+
attn_num_head_channels=1,
|
1157 |
+
cross_attention_dim=1280,
|
1158 |
+
attention_type="default",
|
1159 |
+
output_scale_factor=1.0,
|
1160 |
+
add_upsample=True,
|
1161 |
+
dual_cross_attention=False,
|
1162 |
+
use_linear_projection=False,
|
1163 |
+
only_cross_attention=False,
|
1164 |
+
):
|
1165 |
+
super().__init__()
|
1166 |
+
resnets = []
|
1167 |
+
attentions = []
|
1168 |
+
|
1169 |
+
self.attention_type = attention_type
|
1170 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1171 |
+
|
1172 |
+
for i in range(num_layers):
|
1173 |
+
res_skip_channels = in_channels if (
|
1174 |
+
i == num_layers - 1) else out_channels
|
1175 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1176 |
+
|
1177 |
+
resnets.append(
|
1178 |
+
ResnetBlock2D(
|
1179 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1180 |
+
out_channels=out_channels,
|
1181 |
+
temb_channels=temb_channels,
|
1182 |
+
eps=resnet_eps,
|
1183 |
+
groups=resnet_groups,
|
1184 |
+
dropout=dropout,
|
1185 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1186 |
+
non_linearity=resnet_act_fn,
|
1187 |
+
output_scale_factor=output_scale_factor,
|
1188 |
+
pre_norm=resnet_pre_norm,
|
1189 |
+
)
|
1190 |
+
)
|
1191 |
+
if not dual_cross_attention:
|
1192 |
+
attentions.append(
|
1193 |
+
Transformer2DModel(
|
1194 |
+
attn_num_head_channels,
|
1195 |
+
out_channels // attn_num_head_channels,
|
1196 |
+
in_channels=out_channels,
|
1197 |
+
num_layers=1,
|
1198 |
+
cross_attention_dim=cross_attention_dim,
|
1199 |
+
norm_num_groups=resnet_groups,
|
1200 |
+
use_linear_projection=use_linear_projection,
|
1201 |
+
only_cross_attention=only_cross_attention,
|
1202 |
+
)
|
1203 |
+
)
|
1204 |
+
else:
|
1205 |
+
attentions.append(
|
1206 |
+
DualTransformer2DModel(
|
1207 |
+
attn_num_head_channels,
|
1208 |
+
out_channels // attn_num_head_channels,
|
1209 |
+
in_channels=out_channels,
|
1210 |
+
num_layers=1,
|
1211 |
+
cross_attention_dim=cross_attention_dim,
|
1212 |
+
norm_num_groups=resnet_groups,
|
1213 |
+
)
|
1214 |
+
)
|
1215 |
+
self.attentions = nn.ModuleList(attentions)
|
1216 |
+
self.resnets = nn.ModuleList(resnets)
|
1217 |
+
|
1218 |
+
if add_upsample:
|
1219 |
+
self.upsamplers = nn.ModuleList(
|
1220 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1221 |
+
else:
|
1222 |
+
self.upsamplers = None
|
1223 |
+
|
1224 |
+
self.gradient_checkpointing = False
|
1225 |
+
|
1226 |
+
def set_attention_slice(self, slice_size):
|
1227 |
+
head_dims = self.attn_num_head_channels
|
1228 |
+
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
1229 |
+
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
1230 |
+
raise ValueError(
|
1231 |
+
f"Make sure slice_size {slice_size} is a common divisor of "
|
1232 |
+
f"the number of heads used in cross_attention: {head_dims}"
|
1233 |
+
)
|
1234 |
+
if slice_size is not None and slice_size > min(head_dims):
|
1235 |
+
raise ValueError(
|
1236 |
+
f"slice_size {slice_size} has to be smaller or equal to "
|
1237 |
+
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
1238 |
+
)
|
1239 |
+
|
1240 |
+
for attn in self.attentions:
|
1241 |
+
attn._set_attention_slice(slice_size)
|
1242 |
+
|
1243 |
+
self.gradient_checkpointing = False
|
1244 |
+
|
1245 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
1246 |
+
for attn in self.attentions:
|
1247 |
+
attn._set_use_memory_efficient_attention_xformers(
|
1248 |
+
use_memory_efficient_attention_xformers)
|
1249 |
+
|
1250 |
+
def forward(
|
1251 |
+
self,
|
1252 |
+
hidden_states,
|
1253 |
+
res_hidden_states_tuple,
|
1254 |
+
temb=None,
|
1255 |
+
encoder_hidden_states=None,
|
1256 |
+
upsample_size=None,
|
1257 |
+
text_format_dict={}
|
1258 |
+
):
|
1259 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1260 |
+
# pop res hidden states
|
1261 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1262 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1263 |
+
hidden_states = torch.cat(
|
1264 |
+
[hidden_states, res_hidden_states], dim=1)
|
1265 |
+
|
1266 |
+
if self.training and self.gradient_checkpointing:
|
1267 |
+
|
1268 |
+
def create_custom_forward(module, return_dict=None):
|
1269 |
+
def custom_forward(*inputs):
|
1270 |
+
if return_dict is not None:
|
1271 |
+
return module(*inputs, return_dict=return_dict)
|
1272 |
+
else:
|
1273 |
+
return module(*inputs)
|
1274 |
+
|
1275 |
+
return custom_forward
|
1276 |
+
|
1277 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1278 |
+
create_custom_forward(resnet), hidden_states, temb)
|
1279 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1280 |
+
create_custom_forward(
|
1281 |
+
attn, return_dict=False), hidden_states, encoder_hidden_states,
|
1282 |
+
text_format_dict
|
1283 |
+
)[0]
|
1284 |
+
else:
|
1285 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1286 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
|
1287 |
+
text_format_dict=text_format_dict).sample
|
1288 |
+
|
1289 |
+
if self.upsamplers is not None:
|
1290 |
+
for upsampler in self.upsamplers:
|
1291 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1292 |
+
|
1293 |
+
return hidden_states
|
1294 |
+
|
1295 |
+
|
1296 |
+
class UpBlock2D(nn.Module):
|
1297 |
+
def __init__(
|
1298 |
+
self,
|
1299 |
+
in_channels: int,
|
1300 |
+
prev_output_channel: int,
|
1301 |
+
out_channels: int,
|
1302 |
+
temb_channels: int,
|
1303 |
+
dropout: float = 0.0,
|
1304 |
+
num_layers: int = 1,
|
1305 |
+
resnet_eps: float = 1e-6,
|
1306 |
+
resnet_time_scale_shift: str = "default",
|
1307 |
+
resnet_act_fn: str = "swish",
|
1308 |
+
resnet_groups: int = 32,
|
1309 |
+
resnet_pre_norm: bool = True,
|
1310 |
+
output_scale_factor=1.0,
|
1311 |
+
add_upsample=True,
|
1312 |
+
):
|
1313 |
+
super().__init__()
|
1314 |
+
resnets = []
|
1315 |
+
|
1316 |
+
for i in range(num_layers):
|
1317 |
+
res_skip_channels = in_channels if (
|
1318 |
+
i == num_layers - 1) else out_channels
|
1319 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1320 |
+
|
1321 |
+
resnets.append(
|
1322 |
+
ResnetBlock2D(
|
1323 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1324 |
+
out_channels=out_channels,
|
1325 |
+
temb_channels=temb_channels,
|
1326 |
+
eps=resnet_eps,
|
1327 |
+
groups=resnet_groups,
|
1328 |
+
dropout=dropout,
|
1329 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1330 |
+
non_linearity=resnet_act_fn,
|
1331 |
+
output_scale_factor=output_scale_factor,
|
1332 |
+
pre_norm=resnet_pre_norm,
|
1333 |
+
)
|
1334 |
+
)
|
1335 |
+
|
1336 |
+
self.resnets = nn.ModuleList(resnets)
|
1337 |
+
|
1338 |
+
if add_upsample:
|
1339 |
+
self.upsamplers = nn.ModuleList(
|
1340 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1341 |
+
else:
|
1342 |
+
self.upsamplers = None
|
1343 |
+
|
1344 |
+
self.gradient_checkpointing = False
|
1345 |
+
|
1346 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1347 |
+
for resnet in self.resnets:
|
1348 |
+
# pop res hidden states
|
1349 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1350 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1351 |
+
hidden_states = torch.cat(
|
1352 |
+
[hidden_states, res_hidden_states], dim=1)
|
1353 |
+
|
1354 |
+
if self.training and self.gradient_checkpointing:
|
1355 |
+
|
1356 |
+
def create_custom_forward(module):
|
1357 |
+
def custom_forward(*inputs):
|
1358 |
+
return module(*inputs)
|
1359 |
+
|
1360 |
+
return custom_forward
|
1361 |
+
|
1362 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1363 |
+
create_custom_forward(resnet), hidden_states, temb)
|
1364 |
+
else:
|
1365 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1366 |
+
|
1367 |
+
if self.upsamplers is not None:
|
1368 |
+
for upsampler in self.upsamplers:
|
1369 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1370 |
+
|
1371 |
+
return hidden_states
|
1372 |
+
|
1373 |
+
|
1374 |
+
class UpDecoderBlock2D(nn.Module):
|
1375 |
+
def __init__(
|
1376 |
+
self,
|
1377 |
+
in_channels: int,
|
1378 |
+
out_channels: int,
|
1379 |
+
dropout: float = 0.0,
|
1380 |
+
num_layers: int = 1,
|
1381 |
+
resnet_eps: float = 1e-6,
|
1382 |
+
resnet_time_scale_shift: str = "default",
|
1383 |
+
resnet_act_fn: str = "swish",
|
1384 |
+
resnet_groups: int = 32,
|
1385 |
+
resnet_pre_norm: bool = True,
|
1386 |
+
output_scale_factor=1.0,
|
1387 |
+
add_upsample=True,
|
1388 |
+
):
|
1389 |
+
super().__init__()
|
1390 |
+
resnets = []
|
1391 |
+
|
1392 |
+
for i in range(num_layers):
|
1393 |
+
input_channels = in_channels if i == 0 else out_channels
|
1394 |
+
|
1395 |
+
resnets.append(
|
1396 |
+
ResnetBlock2D(
|
1397 |
+
in_channels=input_channels,
|
1398 |
+
out_channels=out_channels,
|
1399 |
+
temb_channels=None,
|
1400 |
+
eps=resnet_eps,
|
1401 |
+
groups=resnet_groups,
|
1402 |
+
dropout=dropout,
|
1403 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1404 |
+
non_linearity=resnet_act_fn,
|
1405 |
+
output_scale_factor=output_scale_factor,
|
1406 |
+
pre_norm=resnet_pre_norm,
|
1407 |
+
)
|
1408 |
+
)
|
1409 |
+
|
1410 |
+
self.resnets = nn.ModuleList(resnets)
|
1411 |
+
|
1412 |
+
if add_upsample:
|
1413 |
+
self.upsamplers = nn.ModuleList(
|
1414 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1415 |
+
else:
|
1416 |
+
self.upsamplers = None
|
1417 |
+
|
1418 |
+
def forward(self, hidden_states):
|
1419 |
+
for resnet in self.resnets:
|
1420 |
+
hidden_states, _ = resnet(hidden_states, temb=None)
|
1421 |
+
|
1422 |
+
if self.upsamplers is not None:
|
1423 |
+
for upsampler in self.upsamplers:
|
1424 |
+
hidden_states = upsampler(hidden_states)
|
1425 |
+
|
1426 |
+
return hidden_states
|
1427 |
+
|
1428 |
+
|
1429 |
+
class AttnUpDecoderBlock2D(nn.Module):
|
1430 |
+
def __init__(
|
1431 |
+
self,
|
1432 |
+
in_channels: int,
|
1433 |
+
out_channels: int,
|
1434 |
+
dropout: float = 0.0,
|
1435 |
+
num_layers: int = 1,
|
1436 |
+
resnet_eps: float = 1e-6,
|
1437 |
+
resnet_time_scale_shift: str = "default",
|
1438 |
+
resnet_act_fn: str = "swish",
|
1439 |
+
resnet_groups: int = 32,
|
1440 |
+
resnet_pre_norm: bool = True,
|
1441 |
+
attn_num_head_channels=1,
|
1442 |
+
output_scale_factor=1.0,
|
1443 |
+
add_upsample=True,
|
1444 |
+
):
|
1445 |
+
super().__init__()
|
1446 |
+
resnets = []
|
1447 |
+
attentions = []
|
1448 |
+
|
1449 |
+
for i in range(num_layers):
|
1450 |
+
input_channels = in_channels if i == 0 else out_channels
|
1451 |
+
|
1452 |
+
resnets.append(
|
1453 |
+
ResnetBlock2D(
|
1454 |
+
in_channels=input_channels,
|
1455 |
+
out_channels=out_channels,
|
1456 |
+
temb_channels=None,
|
1457 |
+
eps=resnet_eps,
|
1458 |
+
groups=resnet_groups,
|
1459 |
+
dropout=dropout,
|
1460 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1461 |
+
non_linearity=resnet_act_fn,
|
1462 |
+
output_scale_factor=output_scale_factor,
|
1463 |
+
pre_norm=resnet_pre_norm,
|
1464 |
+
)
|
1465 |
+
)
|
1466 |
+
attentions.append(
|
1467 |
+
AttentionBlock(
|
1468 |
+
out_channels,
|
1469 |
+
num_head_channels=attn_num_head_channels,
|
1470 |
+
rescale_output_factor=output_scale_factor,
|
1471 |
+
eps=resnet_eps,
|
1472 |
+
norm_num_groups=resnet_groups,
|
1473 |
+
)
|
1474 |
+
)
|
1475 |
+
|
1476 |
+
self.attentions = nn.ModuleList(attentions)
|
1477 |
+
self.resnets = nn.ModuleList(resnets)
|
1478 |
+
|
1479 |
+
if add_upsample:
|
1480 |
+
self.upsamplers = nn.ModuleList(
|
1481 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1482 |
+
else:
|
1483 |
+
self.upsamplers = None
|
1484 |
+
|
1485 |
+
def forward(self, hidden_states):
|
1486 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1487 |
+
hidden_states, _ = resnet(hidden_states, temb=None)
|
1488 |
+
hidden_states = attn(hidden_states)
|
1489 |
+
|
1490 |
+
if self.upsamplers is not None:
|
1491 |
+
for upsampler in self.upsamplers:
|
1492 |
+
hidden_states = upsampler(hidden_states)
|
1493 |
+
|
1494 |
+
return hidden_states
|
1495 |
+
|
1496 |
+
|
1497 |
+
class AttnSkipUpBlock2D(nn.Module):
|
1498 |
+
def __init__(
|
1499 |
+
self,
|
1500 |
+
in_channels: int,
|
1501 |
+
prev_output_channel: int,
|
1502 |
+
out_channels: int,
|
1503 |
+
temb_channels: int,
|
1504 |
+
dropout: float = 0.0,
|
1505 |
+
num_layers: int = 1,
|
1506 |
+
resnet_eps: float = 1e-6,
|
1507 |
+
resnet_time_scale_shift: str = "default",
|
1508 |
+
resnet_act_fn: str = "swish",
|
1509 |
+
resnet_pre_norm: bool = True,
|
1510 |
+
attn_num_head_channels=1,
|
1511 |
+
attention_type="default",
|
1512 |
+
output_scale_factor=np.sqrt(2.0),
|
1513 |
+
upsample_padding=1,
|
1514 |
+
add_upsample=True,
|
1515 |
+
):
|
1516 |
+
super().__init__()
|
1517 |
+
self.attentions = nn.ModuleList([])
|
1518 |
+
self.resnets = nn.ModuleList([])
|
1519 |
+
|
1520 |
+
self.attention_type = attention_type
|
1521 |
+
|
1522 |
+
for i in range(num_layers):
|
1523 |
+
res_skip_channels = in_channels if (
|
1524 |
+
i == num_layers - 1) else out_channels
|
1525 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1526 |
+
|
1527 |
+
self.resnets.append(
|
1528 |
+
ResnetBlock2D(
|
1529 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1530 |
+
out_channels=out_channels,
|
1531 |
+
temb_channels=temb_channels,
|
1532 |
+
eps=resnet_eps,
|
1533 |
+
groups=min(resnet_in_channels +
|
1534 |
+
res_skip_channels // 4, 32),
|
1535 |
+
groups_out=min(out_channels // 4, 32),
|
1536 |
+
dropout=dropout,
|
1537 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1538 |
+
non_linearity=resnet_act_fn,
|
1539 |
+
output_scale_factor=output_scale_factor,
|
1540 |
+
pre_norm=resnet_pre_norm,
|
1541 |
+
)
|
1542 |
+
)
|
1543 |
+
|
1544 |
+
self.attentions.append(
|
1545 |
+
AttentionBlock(
|
1546 |
+
out_channels,
|
1547 |
+
num_head_channels=attn_num_head_channels,
|
1548 |
+
rescale_output_factor=output_scale_factor,
|
1549 |
+
eps=resnet_eps,
|
1550 |
+
)
|
1551 |
+
)
|
1552 |
+
|
1553 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1554 |
+
if add_upsample:
|
1555 |
+
self.resnet_up = ResnetBlock2D(
|
1556 |
+
in_channels=out_channels,
|
1557 |
+
out_channels=out_channels,
|
1558 |
+
temb_channels=temb_channels,
|
1559 |
+
eps=resnet_eps,
|
1560 |
+
groups=min(out_channels // 4, 32),
|
1561 |
+
groups_out=min(out_channels // 4, 32),
|
1562 |
+
dropout=dropout,
|
1563 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1564 |
+
non_linearity=resnet_act_fn,
|
1565 |
+
output_scale_factor=output_scale_factor,
|
1566 |
+
pre_norm=resnet_pre_norm,
|
1567 |
+
use_in_shortcut=True,
|
1568 |
+
up=True,
|
1569 |
+
kernel="fir",
|
1570 |
+
)
|
1571 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
|
1572 |
+
3, 3), stride=(1, 1), padding=(1, 1))
|
1573 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1574 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1575 |
+
)
|
1576 |
+
self.act = nn.SiLU()
|
1577 |
+
else:
|
1578 |
+
self.resnet_up = None
|
1579 |
+
self.skip_conv = None
|
1580 |
+
self.skip_norm = None
|
1581 |
+
self.act = None
|
1582 |
+
|
1583 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1584 |
+
for resnet in self.resnets:
|
1585 |
+
# pop res hidden states
|
1586 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1587 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1588 |
+
hidden_states = torch.cat(
|
1589 |
+
[hidden_states, res_hidden_states], dim=1)
|
1590 |
+
|
1591 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1592 |
+
|
1593 |
+
hidden_states = self.attentions[0](hidden_states)
|
1594 |
+
|
1595 |
+
if skip_sample is not None:
|
1596 |
+
skip_sample = self.upsampler(skip_sample)
|
1597 |
+
else:
|
1598 |
+
skip_sample = 0
|
1599 |
+
|
1600 |
+
if self.resnet_up is not None:
|
1601 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1602 |
+
skip_sample_states = self.act(skip_sample_states)
|
1603 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1604 |
+
|
1605 |
+
skip_sample = skip_sample + skip_sample_states
|
1606 |
+
|
1607 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1608 |
+
|
1609 |
+
return hidden_states, skip_sample
|
1610 |
+
|
1611 |
+
|
1612 |
+
class SkipUpBlock2D(nn.Module):
|
1613 |
+
def __init__(
|
1614 |
+
self,
|
1615 |
+
in_channels: int,
|
1616 |
+
prev_output_channel: int,
|
1617 |
+
out_channels: int,
|
1618 |
+
temb_channels: int,
|
1619 |
+
dropout: float = 0.0,
|
1620 |
+
num_layers: int = 1,
|
1621 |
+
resnet_eps: float = 1e-6,
|
1622 |
+
resnet_time_scale_shift: str = "default",
|
1623 |
+
resnet_act_fn: str = "swish",
|
1624 |
+
resnet_pre_norm: bool = True,
|
1625 |
+
output_scale_factor=np.sqrt(2.0),
|
1626 |
+
add_upsample=True,
|
1627 |
+
upsample_padding=1,
|
1628 |
+
):
|
1629 |
+
super().__init__()
|
1630 |
+
self.resnets = nn.ModuleList([])
|
1631 |
+
|
1632 |
+
for i in range(num_layers):
|
1633 |
+
res_skip_channels = in_channels if (
|
1634 |
+
i == num_layers - 1) else out_channels
|
1635 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1636 |
+
|
1637 |
+
self.resnets.append(
|
1638 |
+
ResnetBlock2D(
|
1639 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1640 |
+
out_channels=out_channels,
|
1641 |
+
temb_channels=temb_channels,
|
1642 |
+
eps=resnet_eps,
|
1643 |
+
groups=min(
|
1644 |
+
(resnet_in_channels + res_skip_channels) // 4, 32),
|
1645 |
+
groups_out=min(out_channels // 4, 32),
|
1646 |
+
dropout=dropout,
|
1647 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1648 |
+
non_linearity=resnet_act_fn,
|
1649 |
+
output_scale_factor=output_scale_factor,
|
1650 |
+
pre_norm=resnet_pre_norm,
|
1651 |
+
)
|
1652 |
+
)
|
1653 |
+
|
1654 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1655 |
+
if add_upsample:
|
1656 |
+
self.resnet_up = ResnetBlock2D(
|
1657 |
+
in_channels=out_channels,
|
1658 |
+
out_channels=out_channels,
|
1659 |
+
temb_channels=temb_channels,
|
1660 |
+
eps=resnet_eps,
|
1661 |
+
groups=min(out_channels // 4, 32),
|
1662 |
+
groups_out=min(out_channels // 4, 32),
|
1663 |
+
dropout=dropout,
|
1664 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1665 |
+
non_linearity=resnet_act_fn,
|
1666 |
+
output_scale_factor=output_scale_factor,
|
1667 |
+
pre_norm=resnet_pre_norm,
|
1668 |
+
use_in_shortcut=True,
|
1669 |
+
up=True,
|
1670 |
+
kernel="fir",
|
1671 |
+
)
|
1672 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
|
1673 |
+
3, 3), stride=(1, 1), padding=(1, 1))
|
1674 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1675 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1676 |
+
)
|
1677 |
+
self.act = nn.SiLU()
|
1678 |
+
else:
|
1679 |
+
self.resnet_up = None
|
1680 |
+
self.skip_conv = None
|
1681 |
+
self.skip_norm = None
|
1682 |
+
self.act = None
|
1683 |
+
|
1684 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1685 |
+
for resnet in self.resnets:
|
1686 |
+
# pop res hidden states
|
1687 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1688 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1689 |
+
hidden_states = torch.cat(
|
1690 |
+
[hidden_states, res_hidden_states], dim=1)
|
1691 |
+
|
1692 |
+
hidden_states, _ = resnet(hidden_states, temb)
|
1693 |
+
|
1694 |
+
if skip_sample is not None:
|
1695 |
+
skip_sample = self.upsampler(skip_sample)
|
1696 |
+
else:
|
1697 |
+
skip_sample = 0
|
1698 |
+
|
1699 |
+
if self.resnet_up is not None:
|
1700 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1701 |
+
skip_sample_states = self.act(skip_sample_states)
|
1702 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1703 |
+
|
1704 |
+
skip_sample = skip_sample + skip_sample_states
|
1705 |
+
|
1706 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1707 |
+
|
1708 |
+
return hidden_states, skip_sample
|
1709 |
+
|
1710 |
+
|
1711 |
+
class ResnetBlock2D(nn.Module):
|
1712 |
+
def __init__(
|
1713 |
+
self,
|
1714 |
+
*,
|
1715 |
+
in_channels,
|
1716 |
+
out_channels=None,
|
1717 |
+
conv_shortcut=False,
|
1718 |
+
dropout=0.0,
|
1719 |
+
temb_channels=512,
|
1720 |
+
groups=32,
|
1721 |
+
groups_out=None,
|
1722 |
+
pre_norm=True,
|
1723 |
+
eps=1e-6,
|
1724 |
+
non_linearity="swish",
|
1725 |
+
time_embedding_norm="default",
|
1726 |
+
kernel=None,
|
1727 |
+
output_scale_factor=1.0,
|
1728 |
+
use_in_shortcut=None,
|
1729 |
+
up=False,
|
1730 |
+
down=False,
|
1731 |
+
):
|
1732 |
+
super().__init__()
|
1733 |
+
self.pre_norm = pre_norm
|
1734 |
+
self.pre_norm = True
|
1735 |
+
self.in_channels = in_channels
|
1736 |
+
out_channels = in_channels if out_channels is None else out_channels
|
1737 |
+
self.out_channels = out_channels
|
1738 |
+
self.use_conv_shortcut = conv_shortcut
|
1739 |
+
self.time_embedding_norm = time_embedding_norm
|
1740 |
+
self.up = up
|
1741 |
+
self.down = down
|
1742 |
+
self.output_scale_factor = output_scale_factor
|
1743 |
+
|
1744 |
+
if groups_out is None:
|
1745 |
+
groups_out = groups
|
1746 |
+
|
1747 |
+
self.norm1 = torch.nn.GroupNorm(
|
1748 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
1749 |
+
|
1750 |
+
self.conv1 = torch.nn.Conv2d(
|
1751 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1752 |
+
|
1753 |
+
if temb_channels is not None:
|
1754 |
+
if self.time_embedding_norm == "default":
|
1755 |
+
time_emb_proj_out_channels = out_channels
|
1756 |
+
elif self.time_embedding_norm == "scale_shift":
|
1757 |
+
time_emb_proj_out_channels = out_channels * 2
|
1758 |
+
else:
|
1759 |
+
raise ValueError(
|
1760 |
+
f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
1761 |
+
|
1762 |
+
self.time_emb_proj = torch.nn.Linear(
|
1763 |
+
temb_channels, time_emb_proj_out_channels)
|
1764 |
+
else:
|
1765 |
+
self.time_emb_proj = None
|
1766 |
+
|
1767 |
+
self.norm2 = torch.nn.GroupNorm(
|
1768 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
1769 |
+
self.dropout = torch.nn.Dropout(dropout)
|
1770 |
+
self.conv2 = torch.nn.Conv2d(
|
1771 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1772 |
+
|
1773 |
+
if non_linearity == "swish":
|
1774 |
+
self.nonlinearity = lambda x: F.silu(x)
|
1775 |
+
elif non_linearity == "mish":
|
1776 |
+
self.nonlinearity = Mish()
|
1777 |
+
elif non_linearity == "silu":
|
1778 |
+
self.nonlinearity = nn.SiLU()
|
1779 |
+
|
1780 |
+
self.upsample = self.downsample = None
|
1781 |
+
if self.up:
|
1782 |
+
if kernel == "fir":
|
1783 |
+
fir_kernel = (1, 3, 3, 1)
|
1784 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
1785 |
+
elif kernel == "sde_vp":
|
1786 |
+
self.upsample = partial(
|
1787 |
+
F.interpolate, scale_factor=2.0, mode="nearest")
|
1788 |
+
else:
|
1789 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
1790 |
+
elif self.down:
|
1791 |
+
if kernel == "fir":
|
1792 |
+
fir_kernel = (1, 3, 3, 1)
|
1793 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
1794 |
+
elif kernel == "sde_vp":
|
1795 |
+
self.downsample = partial(
|
1796 |
+
F.avg_pool2d, kernel_size=2, stride=2)
|
1797 |
+
else:
|
1798 |
+
self.downsample = Downsample2D(
|
1799 |
+
in_channels, use_conv=False, padding=1, name="op")
|
1800 |
+
|
1801 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
1802 |
+
|
1803 |
+
self.conv_shortcut = None
|
1804 |
+
if self.use_in_shortcut:
|
1805 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
1806 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
1807 |
+
|
1808 |
+
def forward(self, input_tensor, temb, inject_states=None):
|
1809 |
+
hidden_states = input_tensor
|
1810 |
+
|
1811 |
+
hidden_states = self.norm1(hidden_states)
|
1812 |
+
hidden_states = self.nonlinearity(hidden_states)
|
1813 |
+
|
1814 |
+
if self.upsample is not None:
|
1815 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
1816 |
+
if hidden_states.shape[0] >= 64:
|
1817 |
+
input_tensor = input_tensor.contiguous()
|
1818 |
+
hidden_states = hidden_states.contiguous()
|
1819 |
+
input_tensor = self.upsample(input_tensor)
|
1820 |
+
hidden_states = self.upsample(hidden_states)
|
1821 |
+
elif self.downsample is not None:
|
1822 |
+
input_tensor = self.downsample(input_tensor)
|
1823 |
+
hidden_states = self.downsample(hidden_states)
|
1824 |
+
|
1825 |
+
hidden_states = self.conv1(hidden_states)
|
1826 |
+
|
1827 |
+
if temb is not None:
|
1828 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[
|
1829 |
+
:, :, None, None]
|
1830 |
+
|
1831 |
+
if temb is not None and self.time_embedding_norm == "default":
|
1832 |
+
hidden_states = hidden_states + temb
|
1833 |
+
|
1834 |
+
hidden_states = self.norm2(hidden_states)
|
1835 |
+
|
1836 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
1837 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
1838 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
1839 |
+
|
1840 |
+
hidden_states = self.nonlinearity(hidden_states)
|
1841 |
+
|
1842 |
+
hidden_states = self.dropout(hidden_states)
|
1843 |
+
hidden_states = self.conv2(hidden_states)
|
1844 |
+
|
1845 |
+
if self.conv_shortcut is not None:
|
1846 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
1847 |
+
|
1848 |
+
if inject_states is not None:
|
1849 |
+
output_tensor = (input_tensor + inject_states) / \
|
1850 |
+
self.output_scale_factor
|
1851 |
+
else:
|
1852 |
+
output_tensor = (input_tensor + hidden_states) / \
|
1853 |
+
self.output_scale_factor
|
1854 |
+
|
1855 |
+
return output_tensor, hidden_states
|
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
25 |
+
from .unet_2d_blocks import (
|
26 |
+
CrossAttnDownBlock2D,
|
27 |
+
CrossAttnUpBlock2D,
|
28 |
+
DownBlock2D,
|
29 |
+
UNetMidBlock2DCrossAttn,
|
30 |
+
UpBlock2D,
|
31 |
+
get_down_block,
|
32 |
+
get_up_block,
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class UNet2DConditionOutput(BaseOutput):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
44 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
45 |
+
"""
|
46 |
+
|
47 |
+
sample: torch.FloatTensor
|
48 |
+
|
49 |
+
|
50 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
51 |
+
r"""
|
52 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
53 |
+
and returns sample shaped output.
|
54 |
+
|
55 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
56 |
+
implements for all the models (such as downloading or saving, etc.)
|
57 |
+
|
58 |
+
Parameters:
|
59 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
60 |
+
Height and width of input/output sample.
|
61 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
62 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
63 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
64 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
65 |
+
Whether to flip the sin to cos in the time embedding.
|
66 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
67 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
68 |
+
The tuple of downsample blocks to use.
|
69 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
70 |
+
The tuple of upsample blocks to use.
|
71 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
72 |
+
The tuple of output channels for each block.
|
73 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
74 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
75 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
76 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
77 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
78 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
79 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
80 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
81 |
+
"""
|
82 |
+
|
83 |
+
_supports_gradient_checkpointing = True
|
84 |
+
|
85 |
+
@register_to_config
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
sample_size: Optional[int] = None,
|
89 |
+
in_channels: int = 4,
|
90 |
+
out_channels: int = 4,
|
91 |
+
center_input_sample: bool = False,
|
92 |
+
flip_sin_to_cos: bool = True,
|
93 |
+
freq_shift: int = 0,
|
94 |
+
down_block_types: Tuple[str] = (
|
95 |
+
"CrossAttnDownBlock2D",
|
96 |
+
"CrossAttnDownBlock2D",
|
97 |
+
"CrossAttnDownBlock2D",
|
98 |
+
"DownBlock2D",
|
99 |
+
),
|
100 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
101 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
102 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
103 |
+
layers_per_block: int = 2,
|
104 |
+
downsample_padding: int = 1,
|
105 |
+
mid_block_scale_factor: float = 1,
|
106 |
+
act_fn: str = "silu",
|
107 |
+
norm_num_groups: int = 32,
|
108 |
+
norm_eps: float = 1e-5,
|
109 |
+
cross_attention_dim: int = 1280,
|
110 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
111 |
+
dual_cross_attention: bool = False,
|
112 |
+
use_linear_projection: bool = False,
|
113 |
+
num_class_embeds: Optional[int] = None,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
|
117 |
+
self.sample_size = sample_size
|
118 |
+
time_embed_dim = block_out_channels[0] * 4
|
119 |
+
# import ipdb;ipdb.set_trace()
|
120 |
+
|
121 |
+
# input
|
122 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
123 |
+
|
124 |
+
# time
|
125 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
126 |
+
timestep_input_dim = block_out_channels[0]
|
127 |
+
|
128 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
129 |
+
|
130 |
+
# class embedding
|
131 |
+
if num_class_embeds is not None:
|
132 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
133 |
+
|
134 |
+
self.down_blocks = nn.ModuleList([])
|
135 |
+
self.mid_block = None
|
136 |
+
self.up_blocks = nn.ModuleList([])
|
137 |
+
|
138 |
+
if isinstance(only_cross_attention, bool):
|
139 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
140 |
+
|
141 |
+
if isinstance(attention_head_dim, int):
|
142 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
143 |
+
|
144 |
+
# down
|
145 |
+
output_channel = block_out_channels[0]
|
146 |
+
for i, down_block_type in enumerate(down_block_types):
|
147 |
+
input_channel = output_channel
|
148 |
+
output_channel = block_out_channels[i]
|
149 |
+
is_final_block = i == len(block_out_channels) - 1
|
150 |
+
|
151 |
+
down_block = get_down_block(
|
152 |
+
down_block_type,
|
153 |
+
num_layers=layers_per_block,
|
154 |
+
in_channels=input_channel,
|
155 |
+
out_channels=output_channel,
|
156 |
+
temb_channels=time_embed_dim,
|
157 |
+
add_downsample=not is_final_block,
|
158 |
+
resnet_eps=norm_eps,
|
159 |
+
resnet_act_fn=act_fn,
|
160 |
+
resnet_groups=norm_num_groups,
|
161 |
+
cross_attention_dim=cross_attention_dim,
|
162 |
+
attn_num_head_channels=attention_head_dim[i],
|
163 |
+
downsample_padding=downsample_padding,
|
164 |
+
dual_cross_attention=dual_cross_attention,
|
165 |
+
use_linear_projection=use_linear_projection,
|
166 |
+
only_cross_attention=only_cross_attention[i],
|
167 |
+
)
|
168 |
+
self.down_blocks.append(down_block)
|
169 |
+
|
170 |
+
# mid
|
171 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
172 |
+
in_channels=block_out_channels[-1],
|
173 |
+
temb_channels=time_embed_dim,
|
174 |
+
resnet_eps=norm_eps,
|
175 |
+
resnet_act_fn=act_fn,
|
176 |
+
output_scale_factor=mid_block_scale_factor,
|
177 |
+
resnet_time_scale_shift="default",
|
178 |
+
cross_attention_dim=cross_attention_dim,
|
179 |
+
attn_num_head_channels=attention_head_dim[-1],
|
180 |
+
resnet_groups=norm_num_groups,
|
181 |
+
dual_cross_attention=dual_cross_attention,
|
182 |
+
use_linear_projection=use_linear_projection,
|
183 |
+
)
|
184 |
+
|
185 |
+
# count how many layers upsample the images
|
186 |
+
self.num_upsamplers = 0
|
187 |
+
|
188 |
+
# up
|
189 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
190 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
191 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
192 |
+
output_channel = reversed_block_out_channels[0]
|
193 |
+
for i, up_block_type in enumerate(up_block_types):
|
194 |
+
is_final_block = i == len(block_out_channels) - 1
|
195 |
+
|
196 |
+
prev_output_channel = output_channel
|
197 |
+
output_channel = reversed_block_out_channels[i]
|
198 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
199 |
+
|
200 |
+
# add upsample block for all BUT final layer
|
201 |
+
if not is_final_block:
|
202 |
+
add_upsample = True
|
203 |
+
self.num_upsamplers += 1
|
204 |
+
else:
|
205 |
+
add_upsample = False
|
206 |
+
|
207 |
+
up_block = get_up_block(
|
208 |
+
up_block_type,
|
209 |
+
num_layers=layers_per_block + 1,
|
210 |
+
in_channels=input_channel,
|
211 |
+
out_channels=output_channel,
|
212 |
+
prev_output_channel=prev_output_channel,
|
213 |
+
temb_channels=time_embed_dim,
|
214 |
+
add_upsample=add_upsample,
|
215 |
+
resnet_eps=norm_eps,
|
216 |
+
resnet_act_fn=act_fn,
|
217 |
+
resnet_groups=norm_num_groups,
|
218 |
+
cross_attention_dim=cross_attention_dim,
|
219 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
220 |
+
dual_cross_attention=dual_cross_attention,
|
221 |
+
use_linear_projection=use_linear_projection,
|
222 |
+
only_cross_attention=only_cross_attention[i],
|
223 |
+
)
|
224 |
+
self.up_blocks.append(up_block)
|
225 |
+
prev_output_channel = output_channel
|
226 |
+
|
227 |
+
# out
|
228 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
229 |
+
self.conv_act = nn.SiLU()
|
230 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
231 |
+
|
232 |
+
def set_attention_slice(self, slice_size):
|
233 |
+
head_dims = self.config.attention_head_dim
|
234 |
+
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
235 |
+
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
236 |
+
raise ValueError(
|
237 |
+
f"Make sure slice_size {slice_size} is a common divisor of "
|
238 |
+
f"the number of heads used in cross_attention: {head_dims}"
|
239 |
+
)
|
240 |
+
if slice_size is not None and slice_size > min(head_dims):
|
241 |
+
raise ValueError(
|
242 |
+
f"slice_size {slice_size} has to be smaller or equal to "
|
243 |
+
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
244 |
+
)
|
245 |
+
|
246 |
+
for block in self.down_blocks:
|
247 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
248 |
+
block.set_attention_slice(slice_size)
|
249 |
+
|
250 |
+
self.mid_block.set_attention_slice(slice_size)
|
251 |
+
|
252 |
+
for block in self.up_blocks:
|
253 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
254 |
+
block.set_attention_slice(slice_size)
|
255 |
+
|
256 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
257 |
+
for block in self.down_blocks:
|
258 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
259 |
+
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
260 |
+
|
261 |
+
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
262 |
+
|
263 |
+
for block in self.up_blocks:
|
264 |
+
if hasattr(block, "attentions") and block.attentions is not None:
|
265 |
+
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
266 |
+
|
267 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
268 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
269 |
+
module.gradient_checkpointing = value
|
270 |
+
|
271 |
+
def forward(
|
272 |
+
self,
|
273 |
+
sample: torch.FloatTensor,
|
274 |
+
timestep: Union[torch.Tensor, float, int],
|
275 |
+
encoder_hidden_states: torch.Tensor,
|
276 |
+
class_labels: Optional[torch.Tensor] = None,
|
277 |
+
text_format_dict = {},
|
278 |
+
return_dict: bool = True,
|
279 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
280 |
+
r"""
|
281 |
+
Args:
|
282 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
283 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
284 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
|
285 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
286 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
290 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
291 |
+
returning a tuple, the first element is the sample tensor.
|
292 |
+
"""
|
293 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
294 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
295 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
296 |
+
# on the fly if necessary.
|
297 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
298 |
+
|
299 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
300 |
+
forward_upsample_size = False
|
301 |
+
upsample_size = None
|
302 |
+
|
303 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
304 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
305 |
+
forward_upsample_size = True
|
306 |
+
|
307 |
+
# 0. center input if necessary
|
308 |
+
if self.config.center_input_sample:
|
309 |
+
sample = 2 * sample - 1.0
|
310 |
+
|
311 |
+
# 1. time
|
312 |
+
timesteps = timestep
|
313 |
+
if not torch.is_tensor(timesteps):
|
314 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
315 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
316 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
317 |
+
timesteps = timesteps[None].to(sample.device)
|
318 |
+
|
319 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
320 |
+
timesteps = timesteps.expand(sample.shape[0])
|
321 |
+
|
322 |
+
t_emb = self.time_proj(timesteps)
|
323 |
+
|
324 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
325 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
326 |
+
# there might be better ways to encapsulate this.
|
327 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
328 |
+
emb = self.time_embedding(t_emb)
|
329 |
+
|
330 |
+
if self.config.num_class_embeds is not None:
|
331 |
+
if class_labels is None:
|
332 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
333 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
334 |
+
emb = emb + class_emb
|
335 |
+
|
336 |
+
# 2. pre-process
|
337 |
+
sample = self.conv_in(sample)
|
338 |
+
|
339 |
+
# 3. down
|
340 |
+
down_block_res_samples = (sample,)
|
341 |
+
for downsample_block in self.down_blocks:
|
342 |
+
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
343 |
+
if isinstance(downsample_block, CrossAttnDownBlock2D):
|
344 |
+
sample, res_samples = downsample_block(
|
345 |
+
hidden_states=sample,
|
346 |
+
temb=emb,
|
347 |
+
encoder_hidden_states=encoder_hidden_states,
|
348 |
+
text_format_dict=text_format_dict
|
349 |
+
)
|
350 |
+
else:
|
351 |
+
sample, res_samples = downsample_block(
|
352 |
+
hidden_states=sample,
|
353 |
+
temb=emb,
|
354 |
+
encoder_hidden_states=encoder_hidden_states,
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
if isinstance(downsample_block, CrossAttnDownBlock2D):
|
358 |
+
import ipdb;ipdb.set_trace()
|
359 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
360 |
+
down_block_res_samples += res_samples
|
361 |
+
|
362 |
+
# 4. mid
|
363 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
|
364 |
+
text_format_dict=text_format_dict)
|
365 |
+
|
366 |
+
# 5. up
|
367 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
368 |
+
is_final_block = i == len(self.up_blocks) - 1
|
369 |
+
|
370 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
371 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
372 |
+
|
373 |
+
# if we have not reached the final block and need to forward the
|
374 |
+
# upsample size, we do it here
|
375 |
+
if not is_final_block and forward_upsample_size:
|
376 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
377 |
+
|
378 |
+
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
379 |
+
if isinstance(upsample_block, CrossAttnUpBlock2D):
|
380 |
+
sample = upsample_block(
|
381 |
+
hidden_states=sample,
|
382 |
+
temb=emb,
|
383 |
+
res_hidden_states_tuple=res_samples,
|
384 |
+
encoder_hidden_states=encoder_hidden_states,
|
385 |
+
upsample_size=upsample_size,
|
386 |
+
text_format_dict=text_format_dict
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
sample = upsample_block(
|
390 |
+
hidden_states=sample,
|
391 |
+
temb=emb,
|
392 |
+
res_hidden_states_tuple=res_samples,
|
393 |
+
encoder_hidden_states=encoder_hidden_states,
|
394 |
+
upsample_size=upsample_size,
|
395 |
+
)
|
396 |
+
else:
|
397 |
+
if isinstance(upsample_block, CrossAttnUpBlock2D):
|
398 |
+
upsample_block.attentions
|
399 |
+
import ipdb;ipdb.set_trace()
|
400 |
+
sample = upsample_block(
|
401 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
402 |
+
)
|
403 |
+
# 6. post-process
|
404 |
+
sample = self.conv_norm_out(sample)
|
405 |
+
sample = self.conv_act(sample)
|
406 |
+
sample = self.conv_out(sample)
|
407 |
+
|
408 |
+
if not return_dict:
|
409 |
+
return (sample,)
|
410 |
+
|
411 |
+
return UNet2DConditionOutput(sample=sample)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
2 |
+
torch==1.13.1
|
3 |
+
torchvision==0.14.1
|
4 |
+
diffusers==0.12.1
|
5 |
+
transformers==4.26.0
|
6 |
+
numpy==1.24.2
|
7 |
+
seaborn==0.12.2
|
8 |
+
accelerate==0.16.0
|
9 |
+
scikit-learn==0.24.1
|
rich-text-to-json-iframe.html
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
<head>
|
5 |
+
<title>Rich Text to JSON</title>
|
6 |
+
<link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
|
7 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
|
8 |
+
<link rel="stylesheet" type="text/css"
|
9 |
+
href="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.css">
|
10 |
+
<link rel="stylesheet"
|
11 |
+
href='https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap'>
|
12 |
+
<style>
|
13 |
+
html,
|
14 |
+
body {
|
15 |
+
background-color: white;
|
16 |
+
margin: 0;
|
17 |
+
}
|
18 |
+
|
19 |
+
/* Set default font-family */
|
20 |
+
.ql-snow .ql-tooltip::before {
|
21 |
+
content: "Footnote";
|
22 |
+
line-height: 26px;
|
23 |
+
margin-right: 8px;
|
24 |
+
}
|
25 |
+
|
26 |
+
.ql-snow .ql-tooltip[data-mode=link]::before {
|
27 |
+
content: "Enter footnote:";
|
28 |
+
}
|
29 |
+
|
30 |
+
.row {
|
31 |
+
margin-top: 15px;
|
32 |
+
margin-left: 0px;
|
33 |
+
margin-bottom: 15px;
|
34 |
+
}
|
35 |
+
|
36 |
+
.btn-primary {
|
37 |
+
color: #ffffff;
|
38 |
+
background-color: #2780e3;
|
39 |
+
border-color: #2780e3;
|
40 |
+
}
|
41 |
+
|
42 |
+
.btn-primary:hover {
|
43 |
+
color: #ffffff;
|
44 |
+
background-color: #1967be;
|
45 |
+
border-color: #1862b5;
|
46 |
+
}
|
47 |
+
|
48 |
+
.btn {
|
49 |
+
display: inline-block;
|
50 |
+
margin-bottom: 0;
|
51 |
+
font-weight: normal;
|
52 |
+
text-align: center;
|
53 |
+
vertical-align: middle;
|
54 |
+
touch-action: manipulation;
|
55 |
+
cursor: pointer;
|
56 |
+
background-image: none;
|
57 |
+
border: 1px solid transparent;
|
58 |
+
white-space: nowrap;
|
59 |
+
padding: 10px 18px;
|
60 |
+
font-size: 15px;
|
61 |
+
line-height: 1.42857143;
|
62 |
+
border-radius: 0;
|
63 |
+
user-select: none;
|
64 |
+
}
|
65 |
+
|
66 |
+
#standalone-container {
|
67 |
+
width: 100%;
|
68 |
+
background-color: #ffffff;
|
69 |
+
}
|
70 |
+
|
71 |
+
#editor-container {
|
72 |
+
font-family: "Aref Ruqaa";
|
73 |
+
font-size: 18px;
|
74 |
+
height: 250px;
|
75 |
+
width: 100%;
|
76 |
+
}
|
77 |
+
|
78 |
+
#toolbar-container {
|
79 |
+
font-family: "Aref Ruqaa";
|
80 |
+
display: flex;
|
81 |
+
flex-wrap: wrap;
|
82 |
+
}
|
83 |
+
|
84 |
+
#json-container {
|
85 |
+
max-width: 720px;
|
86 |
+
}
|
87 |
+
|
88 |
+
/* Set dropdown font-families */
|
89 |
+
#toolbar-container .ql-font span[data-label="Base"]::before {
|
90 |
+
font-family: "Aref Ruqaa";
|
91 |
+
}
|
92 |
+
|
93 |
+
#toolbar-container .ql-font span[data-label="Claude Monet"]::before {
|
94 |
+
font-family: "Mirza";
|
95 |
+
}
|
96 |
+
|
97 |
+
#toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
|
98 |
+
font-family: "Roboto";
|
99 |
+
}
|
100 |
+
|
101 |
+
#toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
|
102 |
+
font-family: "Comic Sans MS";
|
103 |
+
}
|
104 |
+
|
105 |
+
#toolbar-container .ql-font span[data-label="Pop Art"]::before {
|
106 |
+
font-family: "sofia";
|
107 |
+
}
|
108 |
+
|
109 |
+
#toolbar-container .ql-font span[data-label="Van Gogh"]::before {
|
110 |
+
font-family: "slabo 27px";
|
111 |
+
}
|
112 |
+
|
113 |
+
#toolbar-container .ql-font span[data-label="Pixel Art"]::before {
|
114 |
+
font-family: "inconsolata";
|
115 |
+
}
|
116 |
+
|
117 |
+
#toolbar-container .ql-font span[data-label="Rembrandt"]::before {
|
118 |
+
font-family: "ubuntu";
|
119 |
+
}
|
120 |
+
|
121 |
+
#toolbar-container .ql-font span[data-label="Cubism"]::before {
|
122 |
+
font-family: "Akronim";
|
123 |
+
}
|
124 |
+
|
125 |
+
#toolbar-container .ql-font span[data-label="Neon Art"]::before {
|
126 |
+
font-family: "Monoton";
|
127 |
+
}
|
128 |
+
|
129 |
+
/* Set content font-families */
|
130 |
+
.ql-font-mirza {
|
131 |
+
font-family: "Mirza";
|
132 |
+
}
|
133 |
+
|
134 |
+
.ql-font-roboto {
|
135 |
+
font-family: "Roboto";
|
136 |
+
}
|
137 |
+
|
138 |
+
.ql-font-cursive {
|
139 |
+
font-family: "Comic Sans MS";
|
140 |
+
}
|
141 |
+
|
142 |
+
.ql-font-sofia {
|
143 |
+
font-family: "sofia";
|
144 |
+
}
|
145 |
+
|
146 |
+
.ql-font-slabo {
|
147 |
+
font-family: "slabo 27px";
|
148 |
+
}
|
149 |
+
|
150 |
+
.ql-font-inconsolata {
|
151 |
+
font-family: "inconsolata";
|
152 |
+
}
|
153 |
+
|
154 |
+
.ql-font-ubuntu {
|
155 |
+
font-family: "ubuntu";
|
156 |
+
}
|
157 |
+
|
158 |
+
.ql-font-Akronim {
|
159 |
+
font-family: "Akronim";
|
160 |
+
}
|
161 |
+
|
162 |
+
.ql-font-Monoton {
|
163 |
+
font-family: "Monoton";
|
164 |
+
}
|
165 |
+
|
166 |
+
.ql-color .ql-picker-options [data-value=Color-Picker] {
|
167 |
+
background: none !important;
|
168 |
+
width: 100% !important;
|
169 |
+
height: 20px !important;
|
170 |
+
text-align: center;
|
171 |
+
}
|
172 |
+
|
173 |
+
.ql-color .ql-picker-options [data-value=Color-Picker]:before {
|
174 |
+
content: 'Color Picker';
|
175 |
+
}
|
176 |
+
|
177 |
+
.ql-color .ql-picker-options [data-value=Color-Picker]:hover {
|
178 |
+
border-color: transparent !important;
|
179 |
+
}
|
180 |
+
</style>
|
181 |
+
</head>
|
182 |
+
|
183 |
+
<body>
|
184 |
+
<div id="standalone-container">
|
185 |
+
<div id="toolbar-container">
|
186 |
+
<span class="ql-formats">
|
187 |
+
<select class="ql-font">
|
188 |
+
<option selected>Base</option>
|
189 |
+
<option value="mirza">Claude Monet</option>
|
190 |
+
<option value="roboto">Ukiyoe</option>
|
191 |
+
<option value="cursive">Cyber Punk</option>
|
192 |
+
<option value="sofia">Pop Art</option>
|
193 |
+
<option value="slabo">Van Gogh</option>
|
194 |
+
<option value="inconsolata">Pixel Art</option>
|
195 |
+
<option value="ubuntu">Rembrandt</option>
|
196 |
+
<option value="Akronim">Cubism</option>
|
197 |
+
<option value="Monoton">Neon Art</option>
|
198 |
+
</select>
|
199 |
+
<select class="ql-size">
|
200 |
+
<option value="18px">Small</option>
|
201 |
+
<option selected>Normal</option>
|
202 |
+
<option value="32px">Large</option>
|
203 |
+
<option value="50px">Huge</option>
|
204 |
+
</select>
|
205 |
+
</span>
|
206 |
+
<span class="ql-formats">
|
207 |
+
<button class="ql-strike"></button>
|
208 |
+
</span>
|
209 |
+
<!-- <span class="ql-formats">
|
210 |
+
<button class="ql-bold"></button>
|
211 |
+
<button class="ql-italic"></button>
|
212 |
+
<button class="ql-underline"></button>
|
213 |
+
</span> -->
|
214 |
+
<span class="ql-formats">
|
215 |
+
<select class="ql-color">
|
216 |
+
<option value="Color-Picker"></option>
|
217 |
+
</select>
|
218 |
+
<!-- <select class="ql-background"></select> -->
|
219 |
+
</span>
|
220 |
+
<!-- <span class="ql-formats">
|
221 |
+
<button class="ql-script" value="sub"></button>
|
222 |
+
<button class="ql-script" value="super"></button>
|
223 |
+
</span>
|
224 |
+
<span class="ql-formats">
|
225 |
+
<button class="ql-header" value="1"></button>
|
226 |
+
<button class="ql-header" value="2"></button>
|
227 |
+
<button class="ql-blockquote"></button>
|
228 |
+
<button class="ql-code-block"></button>
|
229 |
+
</span>
|
230 |
+
<span class="ql-formats">
|
231 |
+
<button class="ql-list" value="ordered"></button>
|
232 |
+
<button class="ql-list" value="bullet"></button>
|
233 |
+
<button class="ql-indent" value="-1"></button>
|
234 |
+
<button class="ql-indent" value="+1"></button>
|
235 |
+
</span>
|
236 |
+
<span class="ql-formats">
|
237 |
+
<button class="ql-direction" value="rtl"></button>
|
238 |
+
<select class="ql-align"></select>
|
239 |
+
</span>
|
240 |
+
<span class="ql-formats">
|
241 |
+
<button class="ql-link"></button>
|
242 |
+
<button class="ql-image"></button>
|
243 |
+
<button class="ql-video"></button>
|
244 |
+
<button class="ql-formula"></button>
|
245 |
+
</span> -->
|
246 |
+
<span class="ql-formats">
|
247 |
+
<button class="ql-link"></button>
|
248 |
+
</span>
|
249 |
+
<span class="ql-formats">
|
250 |
+
<button class="ql-clean"></button>
|
251 |
+
</span>
|
252 |
+
</div>
|
253 |
+
<div id="editor-container" style="height:300px;"></div>
|
254 |
+
</div>
|
255 |
+
<script src="https://cdn.quilljs.com/1.3.6/quill.min.js"></script>
|
256 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.0/jquery.min.js"></script>
|
257 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.js"></script>
|
258 |
+
<script>
|
259 |
+
|
260 |
+
// Register the customs format with Quill
|
261 |
+
const Font = Quill.import('formats/font');
|
262 |
+
Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
|
263 |
+
const Link = Quill.import('formats/link');
|
264 |
+
Link.sanitize = function (url) {
|
265 |
+
// modify url if desired
|
266 |
+
return url;
|
267 |
+
}
|
268 |
+
const SizeStyle = Quill.import('attributors/style/size');
|
269 |
+
SizeStyle.whitelist = ['10px', '18px', '20px', '32px', '50px', '60px', '64px', '70px'];
|
270 |
+
Quill.register(SizeStyle, true);
|
271 |
+
Quill.register(Link, true);
|
272 |
+
Quill.register(Font, true);
|
273 |
+
const icons = Quill.import('ui/icons');
|
274 |
+
icons['link'] = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
|
275 |
+
const quill = new Quill('#editor-container', {
|
276 |
+
modules: {
|
277 |
+
toolbar: {
|
278 |
+
container: '#toolbar-container',
|
279 |
+
},
|
280 |
+
},
|
281 |
+
theme: 'snow'
|
282 |
+
});
|
283 |
+
var toolbar = quill.getModule('toolbar');
|
284 |
+
$(toolbar.container).find('.ql-color').spectrum({
|
285 |
+
preferredFormat: "rgb",
|
286 |
+
showInput: true,
|
287 |
+
showInitial: true,
|
288 |
+
showPalette: true,
|
289 |
+
showSelectionPalette: true,
|
290 |
+
palette: [
|
291 |
+
["#000", "#444", "#666", "#999", "#ccc", "#eee", "#f3f3f3", "#fff"],
|
292 |
+
["#f00", "#f90", "#ff0", "#0f0", "#0ff", "#00f", "#90f", "#f0f"],
|
293 |
+
["#ea9999", "#f9cb9c", "#ffe599", "#b6d7a8", "#a2c4c9", "#9fc5e8", "#b4a7d6", "#d5a6bd"],
|
294 |
+
["#e06666", "#f6b26b", "#ffd966", "#93c47d", "#76a5af", "#6fa8dc", "#8e7cc3", "#c27ba0"],
|
295 |
+
["#c00", "#e69138", "#f1c232", "#6aa84f", "#45818e", "#3d85c6", "#674ea7", "#a64d79"],
|
296 |
+
["#900", "#b45f06", "#bf9000", "#38761d", "#134f5c", "#0b5394", "#351c75", "#741b47"],
|
297 |
+
["#600", "#783f04", "#7f6000", "#274e13", "#0c343d", "#073763", "#20124d", "#4c1130"]
|
298 |
+
],
|
299 |
+
change: function (color) {
|
300 |
+
var value = color.toHexString();
|
301 |
+
quill.format('color', value);
|
302 |
+
}
|
303 |
+
});
|
304 |
+
|
305 |
+
quill.on('text-change', () => {
|
306 |
+
// keep qull data inside _data to communicate with Gradio
|
307 |
+
document.body._data = quill.getContents()
|
308 |
+
})
|
309 |
+
function setQuillContents(content) {
|
310 |
+
quill.setContents(content);
|
311 |
+
document.body._data = quill.getContents();
|
312 |
+
}
|
313 |
+
document.body.setQuillContents = setQuillContents
|
314 |
+
</script>
|
315 |
+
<script src="https://unpkg.com/@popperjs/core@2/dist/umd/popper.min.js"></script>
|
316 |
+
<script src="https://unpkg.com/tippy.js@6/dist/tippy-bundle.umd.js"></script>
|
317 |
+
<script>
|
318 |
+
// With the above scripts loaded, you can call `tippy()` with a CSS
|
319 |
+
// selector and a `content` prop:
|
320 |
+
tippy('.ql-font', {
|
321 |
+
content: 'Add a style to the token',
|
322 |
+
});
|
323 |
+
tippy('.ql-size', {
|
324 |
+
content: 'Reweight the token',
|
325 |
+
});
|
326 |
+
tippy('.ql-color', {
|
327 |
+
content: 'Pick a color for the token',
|
328 |
+
});
|
329 |
+
tippy('.ql-link', {
|
330 |
+
content: 'Clarify the token',
|
331 |
+
});
|
332 |
+
tippy('.ql-strike', {
|
333 |
+
content: 'Change the token weight to be negative',
|
334 |
+
});
|
335 |
+
tippy('.ql-clean', {
|
336 |
+
content: 'Remove all the formats',
|
337 |
+
});
|
338 |
+
</script>
|
339 |
+
</body>
|
340 |
+
|
341 |
+
</html>
|
rich-text-to-json.js
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class RichTextEditor extends HTMLElement {
|
2 |
+
constructor() {
|
3 |
+
super();
|
4 |
+
this.loadExternalScripts();
|
5 |
+
this.attachShadow({ mode: 'open' });
|
6 |
+
this.shadowRoot.innerHTML = `
|
7 |
+
${RichTextEditor.header()}
|
8 |
+
${RichTextEditor.template()}
|
9 |
+
`;
|
10 |
+
}
|
11 |
+
connectedCallback() {
|
12 |
+
this.myQuill = this.mountQuill();
|
13 |
+
}
|
14 |
+
loadExternalScripts() {
|
15 |
+
const links = ["https://cdn.quilljs.com/1.3.6/quill.snow.css", "https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css", "https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap"]
|
16 |
+
links.forEach(link => {
|
17 |
+
const css = document.createElement("link");
|
18 |
+
css.href = link;
|
19 |
+
css.rel = "stylesheet"
|
20 |
+
document.head.appendChild(css);
|
21 |
+
})
|
22 |
+
|
23 |
+
}
|
24 |
+
static template() {
|
25 |
+
return `
|
26 |
+
<div id="standalone-container">
|
27 |
+
<div id="toolbar-container">
|
28 |
+
<span class="ql-formats">
|
29 |
+
<select class="ql-font">
|
30 |
+
<option selected>Base</option>
|
31 |
+
<option value="mirza">Claude Monet</option>
|
32 |
+
<option value="roboto">Ukiyoe</option>
|
33 |
+
<option value="cursive">Cyber Punk</option>
|
34 |
+
<option value="sofia">Pop Art</option>
|
35 |
+
<option value="slabo">Van Gogh</option>
|
36 |
+
<option value="inconsolata">Pixel Art</option>
|
37 |
+
<option value="ubuntu">Rembrandt</option>
|
38 |
+
<option value="Akronim">Cubism</option>
|
39 |
+
<option value="Monoton">Neon Art</option>
|
40 |
+
</select>
|
41 |
+
<select class="ql-size">
|
42 |
+
<option value="18px">Small</option>
|
43 |
+
<option selected>Normal</option>
|
44 |
+
<option value="32px">Large</option>
|
45 |
+
<option value="50px">Huge</option>
|
46 |
+
</select>
|
47 |
+
</span>
|
48 |
+
<span class="ql-formats">
|
49 |
+
<button class="ql-strike"></button>
|
50 |
+
</span>
|
51 |
+
<!-- <span class="ql-formats">
|
52 |
+
<button class="ql-bold"></button>
|
53 |
+
<button class="ql-italic"></button>
|
54 |
+
<button class="ql-underline"></button>
|
55 |
+
</span> -->
|
56 |
+
<span class="ql-formats">
|
57 |
+
<select class="ql-color"></select>
|
58 |
+
<!-- <select class="ql-background"></select> -->
|
59 |
+
</span>
|
60 |
+
<!-- <span class="ql-formats">
|
61 |
+
<button class="ql-script" value="sub"></button>
|
62 |
+
<button class="ql-script" value="super"></button>
|
63 |
+
</span>
|
64 |
+
<span class="ql-formats">
|
65 |
+
<button class="ql-header" value="1"></button>
|
66 |
+
<button class="ql-header" value="2"></button>
|
67 |
+
<button class="ql-blockquote"></button>
|
68 |
+
<button class="ql-code-block"></button>
|
69 |
+
</span>
|
70 |
+
<span class="ql-formats">
|
71 |
+
<button class="ql-list" value="ordered"></button>
|
72 |
+
<button class="ql-list" value="bullet"></button>
|
73 |
+
<button class="ql-indent" value="-1"></button>
|
74 |
+
<button class="ql-indent" value="+1"></button>
|
75 |
+
</span>
|
76 |
+
<span class="ql-formats">
|
77 |
+
<button class="ql-direction" value="rtl"></button>
|
78 |
+
<select class="ql-align"></select>
|
79 |
+
</span>
|
80 |
+
<span class="ql-formats">
|
81 |
+
<button class="ql-link"></button>
|
82 |
+
<button class="ql-image"></button>
|
83 |
+
<button class="ql-video"></button>
|
84 |
+
<button class="ql-formula"></button>
|
85 |
+
</span> -->
|
86 |
+
<span class="ql-formats">
|
87 |
+
<button class="ql-link"></button>
|
88 |
+
</span>
|
89 |
+
<span class="ql-formats">
|
90 |
+
<button class="ql-clean"></button>
|
91 |
+
</span>
|
92 |
+
</div>
|
93 |
+
<div id="editor-container"></div>
|
94 |
+
</div>
|
95 |
+
`;
|
96 |
+
}
|
97 |
+
|
98 |
+
static header() {
|
99 |
+
return `
|
100 |
+
<link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
|
101 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
|
102 |
+
<style>
|
103 |
+
/* Set default font-family */
|
104 |
+
.ql-snow .ql-tooltip::before {
|
105 |
+
content: "Footnote";
|
106 |
+
line-height: 26px;
|
107 |
+
margin-right: 8px;
|
108 |
+
}
|
109 |
+
|
110 |
+
.ql-snow .ql-tooltip[data-mode=link]::before {
|
111 |
+
content: "Enter footnote:";
|
112 |
+
}
|
113 |
+
|
114 |
+
.row {
|
115 |
+
margin-top: 15px;
|
116 |
+
margin-left: 0px;
|
117 |
+
margin-bottom: 15px;
|
118 |
+
}
|
119 |
+
|
120 |
+
.btn-primary {
|
121 |
+
color: #ffffff;
|
122 |
+
background-color: #2780e3;
|
123 |
+
border-color: #2780e3;
|
124 |
+
}
|
125 |
+
|
126 |
+
.btn-primary:hover {
|
127 |
+
color: #ffffff;
|
128 |
+
background-color: #1967be;
|
129 |
+
border-color: #1862b5;
|
130 |
+
}
|
131 |
+
|
132 |
+
.btn {
|
133 |
+
display: inline-block;
|
134 |
+
margin-bottom: 0;
|
135 |
+
font-weight: normal;
|
136 |
+
text-align: center;
|
137 |
+
vertical-align: middle;
|
138 |
+
touch-action: manipulation;
|
139 |
+
cursor: pointer;
|
140 |
+
background-image: none;
|
141 |
+
border: 1px solid transparent;
|
142 |
+
white-space: nowrap;
|
143 |
+
padding: 10px 18px;
|
144 |
+
font-size: 15px;
|
145 |
+
line-height: 1.42857143;
|
146 |
+
border-radius: 0;
|
147 |
+
user-select: none;
|
148 |
+
}
|
149 |
+
|
150 |
+
#standalone-container {
|
151 |
+
position: relative;
|
152 |
+
max-width: 720px;
|
153 |
+
background-color: #ffffff;
|
154 |
+
color: black !important;
|
155 |
+
z-index: 1000;
|
156 |
+
}
|
157 |
+
|
158 |
+
#editor-container {
|
159 |
+
font-family: "Aref Ruqaa";
|
160 |
+
font-size: 18px;
|
161 |
+
height: 250px;
|
162 |
+
}
|
163 |
+
|
164 |
+
#toolbar-container {
|
165 |
+
font-family: "Aref Ruqaa";
|
166 |
+
display: flex;
|
167 |
+
flex-wrap: wrap;
|
168 |
+
}
|
169 |
+
|
170 |
+
#json-container {
|
171 |
+
max-width: 720px;
|
172 |
+
}
|
173 |
+
|
174 |
+
/* Set dropdown font-families */
|
175 |
+
#toolbar-container .ql-font span[data-label="Base"]::before {
|
176 |
+
font-family: "Aref Ruqaa";
|
177 |
+
}
|
178 |
+
|
179 |
+
#toolbar-container .ql-font span[data-label="Claude Monet"]::before {
|
180 |
+
font-family: "Mirza";
|
181 |
+
}
|
182 |
+
|
183 |
+
#toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
|
184 |
+
font-family: "Roboto";
|
185 |
+
}
|
186 |
+
|
187 |
+
#toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
|
188 |
+
font-family: "Comic Sans MS";
|
189 |
+
}
|
190 |
+
|
191 |
+
#toolbar-container .ql-font span[data-label="Pop Art"]::before {
|
192 |
+
font-family: "sofia";
|
193 |
+
}
|
194 |
+
|
195 |
+
#toolbar-container .ql-font span[data-label="Van Gogh"]::before {
|
196 |
+
font-family: "slabo 27px";
|
197 |
+
}
|
198 |
+
|
199 |
+
#toolbar-container .ql-font span[data-label="Pixel Art"]::before {
|
200 |
+
font-family: "inconsolata";
|
201 |
+
}
|
202 |
+
|
203 |
+
#toolbar-container .ql-font span[data-label="Rembrandt"]::before {
|
204 |
+
font-family: "ubuntu";
|
205 |
+
}
|
206 |
+
|
207 |
+
#toolbar-container .ql-font span[data-label="Cubism"]::before {
|
208 |
+
font-family: "Akronim";
|
209 |
+
}
|
210 |
+
|
211 |
+
#toolbar-container .ql-font span[data-label="Neon Art"]::before {
|
212 |
+
font-family: "Monoton";
|
213 |
+
}
|
214 |
+
|
215 |
+
/* Set content font-families */
|
216 |
+
.ql-font-mirza {
|
217 |
+
font-family: "Mirza";
|
218 |
+
}
|
219 |
+
|
220 |
+
.ql-font-roboto {
|
221 |
+
font-family: "Roboto";
|
222 |
+
}
|
223 |
+
|
224 |
+
.ql-font-cursive {
|
225 |
+
font-family: "Comic Sans MS";
|
226 |
+
}
|
227 |
+
|
228 |
+
.ql-font-sofia {
|
229 |
+
font-family: "sofia";
|
230 |
+
}
|
231 |
+
|
232 |
+
.ql-font-slabo {
|
233 |
+
font-family: "slabo 27px";
|
234 |
+
}
|
235 |
+
|
236 |
+
.ql-font-inconsolata {
|
237 |
+
font-family: "inconsolata";
|
238 |
+
}
|
239 |
+
|
240 |
+
.ql-font-ubuntu {
|
241 |
+
font-family: "ubuntu";
|
242 |
+
}
|
243 |
+
|
244 |
+
.ql-font-Akronim {
|
245 |
+
font-family: "Akronim";
|
246 |
+
}
|
247 |
+
|
248 |
+
.ql-font-Monoton {
|
249 |
+
font-family: "Monoton";
|
250 |
+
}
|
251 |
+
</style>
|
252 |
+
`;
|
253 |
+
}
|
254 |
+
async mountQuill() {
|
255 |
+
// Register the customs format with Quill
|
256 |
+
const lib = await import("https://cdn.jsdelivr.net/npm/shadow-selection-polyfill");
|
257 |
+
const getRange = lib.getRange;
|
258 |
+
|
259 |
+
const Font = Quill.import('formats/font');
|
260 |
+
Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
|
261 |
+
const Link = Quill.import('formats/link');
|
262 |
+
Link.sanitize = function (url) {
|
263 |
+
// modify url if desired
|
264 |
+
return url;
|
265 |
+
}
|
266 |
+
const SizeStyle = Quill.import('attributors/style/size');
|
267 |
+
SizeStyle.whitelist = ['10px', '18px', '32px', '50px', '64px'];
|
268 |
+
Quill.register(SizeStyle, true);
|
269 |
+
Quill.register(Link, true);
|
270 |
+
Quill.register(Font, true);
|
271 |
+
const icons = Quill.import('ui/icons');
|
272 |
+
const icon = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
|
273 |
+
icons['link'] = icon;
|
274 |
+
const editorContainer = this.shadowRoot.querySelector('#editor-container')
|
275 |
+
const toolbarContainer = this.shadowRoot.querySelector('#toolbar-container')
|
276 |
+
const myQuill = new Quill(editorContainer, {
|
277 |
+
modules: {
|
278 |
+
toolbar: {
|
279 |
+
container: toolbarContainer,
|
280 |
+
},
|
281 |
+
},
|
282 |
+
theme: 'snow'
|
283 |
+
});
|
284 |
+
const normalizeNative = (nativeRange) => {
|
285 |
+
|
286 |
+
if (nativeRange) {
|
287 |
+
const range = nativeRange;
|
288 |
+
|
289 |
+
if (range.baseNode) {
|
290 |
+
range.startContainer = nativeRange.baseNode;
|
291 |
+
range.endContainer = nativeRange.focusNode;
|
292 |
+
range.startOffset = nativeRange.baseOffset;
|
293 |
+
range.endOffset = nativeRange.focusOffset;
|
294 |
+
|
295 |
+
if (range.endOffset < range.startOffset) {
|
296 |
+
range.startContainer = nativeRange.focusNode;
|
297 |
+
range.endContainer = nativeRange.baseNode;
|
298 |
+
range.startOffset = nativeRange.focusOffset;
|
299 |
+
range.endOffset = nativeRange.baseOffset;
|
300 |
+
}
|
301 |
+
}
|
302 |
+
|
303 |
+
if (range.startContainer) {
|
304 |
+
return {
|
305 |
+
start: { node: range.startContainer, offset: range.startOffset },
|
306 |
+
end: { node: range.endContainer, offset: range.endOffset },
|
307 |
+
native: range
|
308 |
+
};
|
309 |
+
}
|
310 |
+
}
|
311 |
+
|
312 |
+
return null
|
313 |
+
};
|
314 |
+
|
315 |
+
myQuill.selection.getNativeRange = () => {
|
316 |
+
|
317 |
+
const dom = myQuill.root.getRootNode();
|
318 |
+
const selection = getRange(dom);
|
319 |
+
const range = normalizeNative(selection);
|
320 |
+
|
321 |
+
return range;
|
322 |
+
};
|
323 |
+
let fromEditor = false;
|
324 |
+
editorContainer.addEventListener("pointerup", (e) => {
|
325 |
+
fromEditor = false;
|
326 |
+
});
|
327 |
+
editorContainer.addEventListener("pointerout", (e) => {
|
328 |
+
fromEditor = false;
|
329 |
+
});
|
330 |
+
editorContainer.addEventListener("pointerdown", (e) => {
|
331 |
+
fromEditor = true;
|
332 |
+
});
|
333 |
+
|
334 |
+
document.addEventListener("selectionchange", () => {
|
335 |
+
if (fromEditor) {
|
336 |
+
myQuill.selection.update()
|
337 |
+
}
|
338 |
+
});
|
339 |
+
|
340 |
+
|
341 |
+
myQuill.on('text-change', () => {
|
342 |
+
// keep qull data inside _data to communicate with Gradio
|
343 |
+
document.querySelector("#rich-text-root")._data = myQuill.getContents()
|
344 |
+
})
|
345 |
+
return myQuill
|
346 |
+
}
|
347 |
+
}
|
348 |
+
|
349 |
+
customElements.define('rich-text-editor', RichTextEditor);
|
share_btn.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
2 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
3 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
4 |
+
</svg>"""
|
5 |
+
|
6 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin" style="color: #ffffff;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
7 |
+
|
8 |
+
share_js = """async () => {
|
9 |
+
async function uploadFile(file){
|
10 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
11 |
+
const response = await fetch(UPLOAD_URL, {
|
12 |
+
method: 'POST',
|
13 |
+
headers: {
|
14 |
+
'Content-Type': file.type,
|
15 |
+
'X-Requested-With': 'XMLHttpRequest',
|
16 |
+
},
|
17 |
+
body: file, /// <- File inherits from Blob
|
18 |
+
});
|
19 |
+
const url = await response.text();
|
20 |
+
return url;
|
21 |
+
}
|
22 |
+
async function getInputImageFile(imageEl){
|
23 |
+
const res = await fetch(imageEl.src);
|
24 |
+
const blob = await res.blob();
|
25 |
+
const imageId = Date.now();
|
26 |
+
const fileName = `rich-text-image-${{imageId}}.png`;
|
27 |
+
return new File([blob], fileName, { type: 'image/png'});
|
28 |
+
}
|
29 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
30 |
+
const richEl = document.getElementById("rich-text-root");
|
31 |
+
const data = richEl? richEl.contentDocument.body._data : {};
|
32 |
+
const text_input = JSON.stringify(data);
|
33 |
+
const negative_prompt = gradioEl.querySelector('#negative_prompt input').value;
|
34 |
+
const seed = gradioEl.querySelector('#seed input').value;
|
35 |
+
const richTextImg = gradioEl.querySelector('#rich-text-image img');
|
36 |
+
const plainTextImg = gradioEl.querySelector('#plain-text-image img');
|
37 |
+
const text_input_obj = JSON.parse(text_input);
|
38 |
+
const plain_prompt = text_input_obj.ops.map(e=> e.insert).join('');
|
39 |
+
const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`;
|
40 |
+
|
41 |
+
const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`;
|
42 |
+
const shareBtnEl = gradioEl.querySelector('#share-btn');
|
43 |
+
const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
|
44 |
+
const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
|
45 |
+
if(!richTextImg){
|
46 |
+
return;
|
47 |
+
};
|
48 |
+
shareBtnEl.style.pointerEvents = 'none';
|
49 |
+
shareIconEl.style.display = 'none';
|
50 |
+
loadingIconEl.style.removeProperty('display');
|
51 |
+
|
52 |
+
const richImgFile = await getInputImageFile(richTextImg);
|
53 |
+
const plainImgFile = await getInputImageFile(plainTextImg);
|
54 |
+
const richImgURL = await uploadFile(richImgFile);
|
55 |
+
const plainImgURL = await uploadFile(plainImgFile);
|
56 |
+
|
57 |
+
const descriptionMd = `
|
58 |
+
### Plain Prompt
|
59 |
+
${plain_prompt}
|
60 |
+
|
61 |
+
🔗 Shareable Link + Params: [here](${linkSrc})
|
62 |
+
|
63 |
+
### Rich Tech Image
|
64 |
+
<img src="${richImgURL}">
|
65 |
+
|
66 |
+
### Plain Text Image
|
67 |
+
<img src="${plainImgURL}">
|
68 |
+
|
69 |
+
`;
|
70 |
+
const params = new URLSearchParams({
|
71 |
+
title: titleTxt,
|
72 |
+
description: descriptionMd,
|
73 |
+
});
|
74 |
+
const paramsStr = params.toString();
|
75 |
+
window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank');
|
76 |
+
shareBtnEl.style.removeProperty('pointer-events');
|
77 |
+
shareIconEl.style.removeProperty('display');
|
78 |
+
loadingIconEl.style.display = 'none';
|
79 |
+
}"""
|
80 |
+
|
81 |
+
css = """
|
82 |
+
#share-btn-container {
|
83 |
+
display: flex;
|
84 |
+
padding-left: 0.5rem !important;
|
85 |
+
padding-right: 0.5rem !important;
|
86 |
+
background-color: #000000;
|
87 |
+
justify-content: center;
|
88 |
+
align-items: center;
|
89 |
+
border-radius: 9999px !important;
|
90 |
+
width: 13rem;
|
91 |
+
margin-top: 10px;
|
92 |
+
margin-left: auto;
|
93 |
+
flex: unset !important;
|
94 |
+
}
|
95 |
+
#share-btn {
|
96 |
+
all: initial;
|
97 |
+
color: #ffffff;
|
98 |
+
font-weight: 600;
|
99 |
+
cursor: pointer;
|
100 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
101 |
+
margin-left: 0.5rem !important;
|
102 |
+
padding-top: 0.25rem !important;
|
103 |
+
padding-bottom: 0.25rem !important;
|
104 |
+
right:0;
|
105 |
+
}
|
106 |
+
#share-btn * {
|
107 |
+
all: unset !important;
|
108 |
+
}
|
109 |
+
#share-btn-container div:nth-child(-n+2){
|
110 |
+
width: auto !important;
|
111 |
+
min-height: 0px !important;
|
112 |
+
}
|
113 |
+
#share-btn-container .wrap {
|
114 |
+
display: none !important;
|
115 |
+
}
|
116 |
+
"""
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/attention_utils.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import matplotlib as mpl
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
from utils.richtext_utils import seed_everything
|
10 |
+
from sklearn.cluster import SpectralClustering
|
11 |
+
|
12 |
+
SelfAttentionLayers = [
|
13 |
+
'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
|
14 |
+
'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
|
15 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
16 |
+
'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
17 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
18 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
19 |
+
'mid_block.attentions.0.transformer_blocks.0.attn1',
|
20 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
|
21 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
|
22 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
|
23 |
+
'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
|
24 |
+
'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
|
25 |
+
'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
|
26 |
+
'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
|
27 |
+
'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
|
28 |
+
'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
|
29 |
+
]
|
30 |
+
|
31 |
+
|
32 |
+
CrossAttentionLayers = [
|
33 |
+
# 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
|
34 |
+
# 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
|
35 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
36 |
+
# 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
37 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
38 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
39 |
+
'mid_block.attentions.0.transformer_blocks.0.attn2',
|
40 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
|
41 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
|
42 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
|
43 |
+
# 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
|
44 |
+
'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
|
45 |
+
# 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
|
46 |
+
# 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
|
47 |
+
# 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
|
48 |
+
# 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
|
49 |
+
]
|
50 |
+
|
51 |
+
|
52 |
+
def split_attention_maps_over_steps(attention_maps):
|
53 |
+
r"""Function for splitting attention maps over steps.
|
54 |
+
Args:
|
55 |
+
attention_maps (dict): Dictionary of attention maps.
|
56 |
+
sampler_order (int): Order of the sampler.
|
57 |
+
"""
|
58 |
+
# This function splits attention maps into unconditional and conditional score and over steps
|
59 |
+
|
60 |
+
attention_maps_cond = dict() # Maps corresponding to conditional score
|
61 |
+
attention_maps_uncond = dict() # Maps corresponding to unconditional score
|
62 |
+
|
63 |
+
for layer in attention_maps.keys():
|
64 |
+
|
65 |
+
for step_num in range(len(attention_maps[layer])):
|
66 |
+
if step_num not in attention_maps_cond:
|
67 |
+
attention_maps_cond[step_num] = dict()
|
68 |
+
attention_maps_uncond[step_num] = dict()
|
69 |
+
|
70 |
+
attention_maps_uncond[step_num].update(
|
71 |
+
{layer: attention_maps[layer][step_num][:1]})
|
72 |
+
attention_maps_cond[step_num].update(
|
73 |
+
{layer: attention_maps[layer][step_num][1:2]})
|
74 |
+
|
75 |
+
return attention_maps_cond, attention_maps_uncond
|
76 |
+
|
77 |
+
|
78 |
+
def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
|
79 |
+
atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
|
80 |
+
for i, attn_map in enumerate(atten_map_list):
|
81 |
+
n_obj = len(attn_map)
|
82 |
+
plt.figure()
|
83 |
+
plt.clf()
|
84 |
+
|
85 |
+
fig, axs = plt.subplots(
|
86 |
+
ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)]+[0.1]))
|
87 |
+
|
88 |
+
fig.set_figheight(3)
|
89 |
+
fig.set_figwidth(3*n_obj+0.1)
|
90 |
+
|
91 |
+
cmap = plt.get_cmap('OrRd')
|
92 |
+
|
93 |
+
vmax = 0
|
94 |
+
vmin = 1
|
95 |
+
for tid in range(n_obj):
|
96 |
+
attention_map_cur = attn_map[tid]
|
97 |
+
vmax = max(vmax, float(attention_map_cur.max()))
|
98 |
+
vmin = min(vmin, float(attention_map_cur.min()))
|
99 |
+
|
100 |
+
for tid in range(n_obj):
|
101 |
+
sns.heatmap(
|
102 |
+
attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
|
103 |
+
cmap=cmap, vmin=vmin, vmax=vmax
|
104 |
+
)
|
105 |
+
axs[tid].set_axis_off()
|
106 |
+
|
107 |
+
if tokens_vis is not None:
|
108 |
+
if tid == n_obj-1:
|
109 |
+
axs_xlabel = 'other tokens'
|
110 |
+
else:
|
111 |
+
axs_xlabel = ''
|
112 |
+
for token_id in obj_tokens[tid]:
|
113 |
+
axs_xlabel += ' ' + tokens_vis[token_id.item() -
|
114 |
+
1][:-len('</w>')]
|
115 |
+
axs[tid].set_title(axs_xlabel)
|
116 |
+
|
117 |
+
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
118 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
119 |
+
fig.colorbar(sm, cax=axs[-1])
|
120 |
+
canvas = fig.canvas
|
121 |
+
canvas.draw()
|
122 |
+
width, height = canvas.get_width_height()
|
123 |
+
img = np.frombuffer(canvas.tostring_rgb(),
|
124 |
+
dtype='uint8').reshape((height, width, 3))
|
125 |
+
|
126 |
+
fig.tight_layout()
|
127 |
+
plt.close()
|
128 |
+
return img
|
129 |
+
|
130 |
+
|
131 |
+
def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
|
132 |
+
r"""Function to visualize attention maps.
|
133 |
+
Args:
|
134 |
+
save_dir (str): Path to save attention maps
|
135 |
+
batch_size (int): Batch size
|
136 |
+
sampler_order (int): Sampler order
|
137 |
+
"""
|
138 |
+
|
139 |
+
# Split attention maps over steps
|
140 |
+
attention_maps_cond, _ = split_attention_maps_over_steps(
|
141 |
+
attention_maps
|
142 |
+
)
|
143 |
+
|
144 |
+
nsteps = len(attention_maps_cond)
|
145 |
+
hw_ori = width * height
|
146 |
+
|
147 |
+
attention_maps = []
|
148 |
+
for obj_token in obj_tokens:
|
149 |
+
attention_maps.append([])
|
150 |
+
|
151 |
+
for step_num in range(nsteps):
|
152 |
+
attention_maps_cur = attention_maps_cond[step_num]
|
153 |
+
|
154 |
+
for layer in attention_maps_cur.keys():
|
155 |
+
if step_num < 10 or layer not in CrossAttentionLayers:
|
156 |
+
continue
|
157 |
+
|
158 |
+
attention_ind = attention_maps_cur[layer].cpu()
|
159 |
+
|
160 |
+
# Attention maps are of shape [batch_size, nkeys, 77]
|
161 |
+
# since they are averaged out while collecting from hooks to save memory.
|
162 |
+
# Now split the heads from batch dimension
|
163 |
+
bs, hw, nclip = attention_ind.shape
|
164 |
+
down_ratio = np.sqrt(hw_ori // hw)
|
165 |
+
width_cur = int(width // down_ratio)
|
166 |
+
height_cur = int(height // down_ratio)
|
167 |
+
attention_ind = attention_ind.reshape(
|
168 |
+
bs, height_cur, width_cur, nclip)
|
169 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
170 |
+
if obj_token[0] == -1:
|
171 |
+
attention_map_prev = torch.stack(
|
172 |
+
[attention_maps[i][-1] for i in range(obj_id)]).sum(0)
|
173 |
+
attention_maps[obj_id].append(
|
174 |
+
attention_map_prev.max()-attention_map_prev)
|
175 |
+
else:
|
176 |
+
obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
|
177 |
+
0].permute([3, 0, 1, 2])
|
178 |
+
obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
|
179 |
+
interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
|
180 |
+
attention_maps[obj_id].append(obj_attention_map)
|
181 |
+
|
182 |
+
# average attention maps over steps
|
183 |
+
attention_maps_averaged = []
|
184 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
185 |
+
if obj_id == len(obj_tokens) - 1:
|
186 |
+
attention_maps_averaged.append(
|
187 |
+
torch.cat(attention_maps[obj_id]).mean(0))
|
188 |
+
else:
|
189 |
+
attention_maps_averaged.append(
|
190 |
+
torch.cat(attention_maps[obj_id]).mean(0))
|
191 |
+
|
192 |
+
# normalize attention maps into [0, 1]
|
193 |
+
attention_maps_averaged_normalized = []
|
194 |
+
attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
|
195 |
+
for obj_id, obj_token in enumerate(obj_tokens):
|
196 |
+
attention_maps_averaged_normalized.append(
|
197 |
+
attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
|
198 |
+
|
199 |
+
# softmax
|
200 |
+
attention_maps_averaged_normalized = (
|
201 |
+
torch.cat(attention_maps_averaged)/0.001).softmax(0)
|
202 |
+
attention_maps_averaged_normalized = [
|
203 |
+
attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
|
204 |
+
|
205 |
+
token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
|
206 |
+
obj_tokens, save_dir, seed, tokens_vis)
|
207 |
+
attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
|
208 |
+
[1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
|
209 |
+
return attention_maps_averaged_normalized, token_maps_vis
|
210 |
+
|
211 |
+
|
212 |
+
def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
|
213 |
+
preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
|
214 |
+
r"""Function to visualize attention maps.
|
215 |
+
Args:
|
216 |
+
save_dir (str): Path to save attention maps
|
217 |
+
batch_size (int): Batch size
|
218 |
+
sampler_order (int): Sampler order
|
219 |
+
"""
|
220 |
+
|
221 |
+
# create the segmentation mask using self-attention maps
|
222 |
+
resolution = 32
|
223 |
+
attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
|
224 |
+
for attn_map in selfattn_maps.values():
|
225 |
+
resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
|
226 |
+
if resolution_map != resolution:
|
227 |
+
continue
|
228 |
+
attn_map = attn_map.reshape(
|
229 |
+
1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
|
230 |
+
attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
|
231 |
+
mode='bicubic', antialias=True)
|
232 |
+
attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
|
233 |
+
1, resolution**2, resolution_map**2))
|
234 |
+
attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
|
235 |
+
for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
|
236 |
+
if save_attn:
|
237 |
+
print('saving self-attention maps...', attn_maps_1024.shape)
|
238 |
+
torch.save(torch.from_numpy(attn_maps_1024),
|
239 |
+
'results/maps/selfattn_maps.pth')
|
240 |
+
seed_everything(seed)
|
241 |
+
sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
|
242 |
+
assign_labels='kmeans')
|
243 |
+
clusters = sc.fit_predict(attn_maps_1024)
|
244 |
+
clusters = clusters.reshape(resolution, resolution)
|
245 |
+
fig = plt.figure()
|
246 |
+
plt.imshow(clusters)
|
247 |
+
plt.axis('off')
|
248 |
+
if return_vis:
|
249 |
+
canvas = fig.canvas
|
250 |
+
canvas.draw()
|
251 |
+
cav_width, cav_height = canvas.get_width_height()
|
252 |
+
segments_vis = np.frombuffer(canvas.tostring_rgb(),
|
253 |
+
dtype='uint8').reshape((cav_height, cav_width, 3))
|
254 |
+
|
255 |
+
plt.close()
|
256 |
+
|
257 |
+
# label the segmentation mask using cross-attention maps
|
258 |
+
cross_attn_maps_1024 = []
|
259 |
+
for attn_map in crossattn_maps.values():
|
260 |
+
resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
|
261 |
+
attn_map = attn_map.reshape(
|
262 |
+
1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2])
|
263 |
+
attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
|
264 |
+
mode='bicubic', antialias=True)
|
265 |
+
cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
|
266 |
+
|
267 |
+
cross_attn_maps_1024 = torch.cat(
|
268 |
+
cross_attn_maps_1024).mean(0).cpu().numpy()
|
269 |
+
if save_attn:
|
270 |
+
print('saving cross-attention maps...', cross_attn_maps_1024.shape)
|
271 |
+
torch.save(torch.from_numpy(cross_attn_maps_1024),
|
272 |
+
'results/maps/crossattn_maps.pth')
|
273 |
+
normalized_span_maps = []
|
274 |
+
for token_ids in obj_tokens:
|
275 |
+
span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
|
276 |
+
normalized_span_map = np.zeros_like(span_token_maps)
|
277 |
+
for i in range(span_token_maps.shape[-1]):
|
278 |
+
curr_noun_map = span_token_maps[:, :, i]
|
279 |
+
normalized_span_map[:, :, i] = (
|
280 |
+
curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
|
281 |
+
normalized_span_maps.append(normalized_span_map)
|
282 |
+
foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
|
283 |
+
) for normalized_span_map in normalized_span_maps]
|
284 |
+
background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
|
285 |
+
for c in range(num_segments):
|
286 |
+
cluster_mask = np.zeros_like(clusters)
|
287 |
+
cluster_mask[clusters == c] = 1.
|
288 |
+
is_foreground = False
|
289 |
+
for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
|
290 |
+
score_maps = [cluster_mask * normalized_span_map[:, :, i]
|
291 |
+
for i in range(len(token_ids))]
|
292 |
+
scores = [score_map.sum() / cluster_mask.sum()
|
293 |
+
for score_map in score_maps]
|
294 |
+
if max(scores) > segment_threshold:
|
295 |
+
foreground_nouns_map += cluster_mask
|
296 |
+
is_foreground = True
|
297 |
+
if not is_foreground:
|
298 |
+
background_map += cluster_mask
|
299 |
+
foreground_token_maps.append(background_map)
|
300 |
+
|
301 |
+
# resize the token maps and visualization
|
302 |
+
resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
|
303 |
+
0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
|
304 |
+
|
305 |
+
resized_token_maps = resized_token_maps / \
|
306 |
+
(resized_token_maps.sum(0, True)+1e-8)
|
307 |
+
resized_token_maps = [token_map.unsqueeze(
|
308 |
+
0) for token_map in resized_token_maps]
|
309 |
+
foreground_token_maps = [token_map[None, :, :]
|
310 |
+
for token_map in foreground_token_maps]
|
311 |
+
token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
|
312 |
+
save_dir, seed, tokens_vis)
|
313 |
+
resized_token_maps = [token_map.unsqueeze(1).repeat(
|
314 |
+
[1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
|
315 |
+
if return_vis:
|
316 |
+
return resized_token_maps, segments_vis, token_maps_vis
|
317 |
+
else:
|
318 |
+
return resized_token_maps
|
utils/richtext_utils.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
COLORS = {
|
8 |
+
'brown': [165, 42, 42],
|
9 |
+
'red': [255, 0, 0],
|
10 |
+
'pink': [253, 108, 158],
|
11 |
+
'orange': [255, 165, 0],
|
12 |
+
'yellow': [255, 255, 0],
|
13 |
+
'purple': [128, 0, 128],
|
14 |
+
'green': [0, 128, 0],
|
15 |
+
'blue': [0, 0, 255],
|
16 |
+
'white': [255, 255, 255],
|
17 |
+
'gray': [128, 128, 128],
|
18 |
+
'black': [0, 0, 0],
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def seed_everything(seed):
|
23 |
+
random.seed(seed)
|
24 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
|
29 |
+
|
30 |
+
def hex_to_rgb(hex_string, return_nearest_color=False):
|
31 |
+
r"""
|
32 |
+
Covert Hex triplet to RGB triplet.
|
33 |
+
"""
|
34 |
+
# Remove '#' symbol if present
|
35 |
+
hex_string = hex_string.lstrip('#')
|
36 |
+
# Convert hex values to integers
|
37 |
+
red = int(hex_string[0:2], 16)
|
38 |
+
green = int(hex_string[2:4], 16)
|
39 |
+
blue = int(hex_string[4:6], 16)
|
40 |
+
rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
|
41 |
+
if return_nearest_color:
|
42 |
+
nearest_color = find_nearest_color(rgb)
|
43 |
+
return rgb.cuda(), nearest_color
|
44 |
+
return rgb.cuda()
|
45 |
+
|
46 |
+
|
47 |
+
def find_nearest_color(rgb):
|
48 |
+
r"""
|
49 |
+
Find the nearest neighbor color given the RGB value.
|
50 |
+
"""
|
51 |
+
if isinstance(rgb, list) or isinstance(rgb, tuple):
|
52 |
+
rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
|
53 |
+
color_distance = torch.FloatTensor([np.linalg.norm(
|
54 |
+
rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()])
|
55 |
+
nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
|
56 |
+
return nearest_color
|
57 |
+
|
58 |
+
|
59 |
+
def font2style(font):
|
60 |
+
r"""
|
61 |
+
Convert the font name to the style name.
|
62 |
+
"""
|
63 |
+
return {'mirza': 'Claud Monet, impressionism, oil on canvas',
|
64 |
+
'roboto': 'Ukiyoe',
|
65 |
+
'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
|
66 |
+
'sofia': 'Pop Art, masterpiece, andy warhol',
|
67 |
+
'slabo': 'Vincent Van Gogh',
|
68 |
+
'inconsolata': 'Pixel Art, 8 bits, 16 bits',
|
69 |
+
'ubuntu': 'Rembrandt',
|
70 |
+
'Monoton': 'neon art, colorful light, highly details, octane render',
|
71 |
+
'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
|
72 |
+
|
73 |
+
|
74 |
+
def parse_json(json_str):
|
75 |
+
r"""
|
76 |
+
Convert the JSON string to attributes.
|
77 |
+
"""
|
78 |
+
# initialze region-base attributes.
|
79 |
+
base_text_prompt = ''
|
80 |
+
style_text_prompts = []
|
81 |
+
footnote_text_prompts = []
|
82 |
+
footnote_target_tokens = []
|
83 |
+
color_text_prompts = []
|
84 |
+
color_rgbs = []
|
85 |
+
color_names = []
|
86 |
+
size_text_prompts_and_sizes = []
|
87 |
+
|
88 |
+
# parse the attributes from JSON.
|
89 |
+
prev_style = None
|
90 |
+
prev_color_rgb = None
|
91 |
+
use_grad_guidance = False
|
92 |
+
for span in json_str['ops']:
|
93 |
+
text_prompt = span['insert'].rstrip('\n')
|
94 |
+
base_text_prompt += span['insert'].rstrip('\n')
|
95 |
+
if text_prompt == ' ':
|
96 |
+
continue
|
97 |
+
if 'attributes' in span:
|
98 |
+
if 'font' in span['attributes']:
|
99 |
+
style = font2style(span['attributes']['font'])
|
100 |
+
if prev_style == style:
|
101 |
+
prev_text_prompt = style_text_prompts[-1].split('in the style of')[
|
102 |
+
0]
|
103 |
+
style_text_prompts[-1] = prev_text_prompt + \
|
104 |
+
' ' + text_prompt + f' in the style of {style}'
|
105 |
+
else:
|
106 |
+
style_text_prompts.append(
|
107 |
+
text_prompt + f' in the style of {style}')
|
108 |
+
prev_style = style
|
109 |
+
else:
|
110 |
+
prev_style = None
|
111 |
+
if 'link' in span['attributes']:
|
112 |
+
footnote_text_prompts.append(span['attributes']['link'])
|
113 |
+
footnote_target_tokens.append(text_prompt)
|
114 |
+
font_size = 1
|
115 |
+
if 'size' in span['attributes'] and 'strike' not in span['attributes']:
|
116 |
+
font_size = float(span['attributes']['size'][:-2])/3.
|
117 |
+
elif 'size' in span['attributes'] and 'strike' in span['attributes']:
|
118 |
+
font_size = -float(span['attributes']['size'][:-2])/3.
|
119 |
+
elif 'size' not in span['attributes'] and 'strike' not in span['attributes']:
|
120 |
+
font_size = 1
|
121 |
+
if 'color' in span['attributes']:
|
122 |
+
use_grad_guidance = True
|
123 |
+
color_rgb, nearest_color = hex_to_rgb(
|
124 |
+
span['attributes']['color'], True)
|
125 |
+
if prev_color_rgb == color_rgb:
|
126 |
+
prev_text_prompt = color_text_prompts[-1]
|
127 |
+
color_text_prompts[-1] = prev_text_prompt + \
|
128 |
+
' ' + text_prompt
|
129 |
+
else:
|
130 |
+
color_rgbs.append(color_rgb)
|
131 |
+
color_names.append(nearest_color)
|
132 |
+
color_text_prompts.append(text_prompt)
|
133 |
+
if font_size != 1:
|
134 |
+
size_text_prompts_and_sizes.append([text_prompt, font_size])
|
135 |
+
return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
136 |
+
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
|
137 |
+
|
138 |
+
|
139 |
+
def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
140 |
+
footnote_target_tokens, color_text_prompts, color_names):
|
141 |
+
r"""
|
142 |
+
Algorithm 1 in the paper.
|
143 |
+
"""
|
144 |
+
region_text_prompts = []
|
145 |
+
region_target_token_ids = []
|
146 |
+
base_tokens = model.tokenizer._tokenize(base_text_prompt)
|
147 |
+
# process the style text prompt
|
148 |
+
for text_prompt in style_text_prompts:
|
149 |
+
region_text_prompts.append(text_prompt)
|
150 |
+
region_target_token_ids.append([])
|
151 |
+
style_tokens = model.tokenizer._tokenize(
|
152 |
+
text_prompt.split('in the style of')[0])
|
153 |
+
for style_token in style_tokens:
|
154 |
+
region_target_token_ids[-1].append(
|
155 |
+
base_tokens.index(style_token)+1)
|
156 |
+
|
157 |
+
# process the complementary text prompt
|
158 |
+
for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
|
159 |
+
region_target_token_ids.append([])
|
160 |
+
region_text_prompts.append(footnote_text_prompt)
|
161 |
+
style_tokens = model.tokenizer._tokenize(text_prompt)
|
162 |
+
for style_token in style_tokens:
|
163 |
+
region_target_token_ids[-1].append(
|
164 |
+
base_tokens.index(style_token)+1)
|
165 |
+
|
166 |
+
# process the color text prompt
|
167 |
+
for color_text_prompt, color_name in zip(color_text_prompts, color_names):
|
168 |
+
region_target_token_ids.append([])
|
169 |
+
region_text_prompts.append(color_name+' '+color_text_prompt)
|
170 |
+
style_tokens = model.tokenizer._tokenize(color_text_prompt)
|
171 |
+
for style_token in style_tokens:
|
172 |
+
region_target_token_ids[-1].append(
|
173 |
+
base_tokens.index(style_token)+1)
|
174 |
+
|
175 |
+
# process the remaining tokens without any attributes
|
176 |
+
region_text_prompts.append(base_text_prompt)
|
177 |
+
region_target_token_ids_all = [
|
178 |
+
id for ids in region_target_token_ids for id in ids]
|
179 |
+
target_token_ids_rest = [id for id in range(
|
180 |
+
1, len(base_tokens)+1) if id not in region_target_token_ids_all]
|
181 |
+
region_target_token_ids.append(target_token_ids_rest)
|
182 |
+
|
183 |
+
region_target_token_ids = [torch.LongTensor(
|
184 |
+
obj_token_id) for obj_token_id in region_target_token_ids]
|
185 |
+
return region_text_prompts, region_target_token_ids, base_tokens
|
186 |
+
|
187 |
+
|
188 |
+
def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
|
189 |
+
r"""
|
190 |
+
Control the token impact using font sizes.
|
191 |
+
"""
|
192 |
+
word_pos = []
|
193 |
+
font_sizes = []
|
194 |
+
for text_prompt, font_size in size_text_prompts_and_sizes:
|
195 |
+
size_tokens = model.tokenizer._tokenize(text_prompt)
|
196 |
+
for size_token in size_tokens:
|
197 |
+
word_pos.append(base_tokens.index(size_token)+1)
|
198 |
+
font_sizes.append(font_size)
|
199 |
+
if len(word_pos) > 0:
|
200 |
+
word_pos = torch.LongTensor(word_pos).cuda()
|
201 |
+
font_sizes = torch.FloatTensor(font_sizes).cuda()
|
202 |
+
else:
|
203 |
+
word_pos = None
|
204 |
+
font_sizes = None
|
205 |
+
text_format_dict = {
|
206 |
+
'word_pos': word_pos,
|
207 |
+
'font_size': font_sizes,
|
208 |
+
}
|
209 |
+
return text_format_dict
|
210 |
+
|
211 |
+
|
212 |
+
def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
|
213 |
+
guidance_start_step=999, color_guidance_weight=1):
|
214 |
+
r"""
|
215 |
+
Control the token impact using font sizes.
|
216 |
+
"""
|
217 |
+
color_target_token_ids = []
|
218 |
+
for text_prompt in color_text_prompts:
|
219 |
+
color_target_token_ids.append([])
|
220 |
+
color_tokens = model.tokenizer._tokenize(text_prompt)
|
221 |
+
for color_token in color_tokens:
|
222 |
+
color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
|
223 |
+
color_target_token_ids_all = [
|
224 |
+
id for ids in color_target_token_ids for id in ids]
|
225 |
+
color_target_token_ids_rest = [id for id in range(
|
226 |
+
1, len(base_tokens)+1) if id not in color_target_token_ids_all]
|
227 |
+
color_target_token_ids.append(color_target_token_ids_rest)
|
228 |
+
color_target_token_ids = [torch.LongTensor(
|
229 |
+
obj_token_id) for obj_token_id in color_target_token_ids]
|
230 |
+
|
231 |
+
text_format_dict['target_RGB'] = color_rgbs
|
232 |
+
text_format_dict['guidance_start_step'] = guidance_start_step
|
233 |
+
text_format_dict['color_guidance_weight'] = color_guidance_weight
|
234 |
+
return text_format_dict, color_target_token_ids
|